Source code for pyllars.mpl_utils

"""
This module contains a number of helper functions for matplotlib.

For details about various arguments, such as allowed keyword
arguments and how they will be interpreted, please consult the
appropriate parts of the matplotlib documentation:

* **Lines**: https://matplotlib.org/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D
* **Patches**: https://matplotlib.org/api/_as_gen/matplotlib.patches.Patch.html#matplotlib.patches.Patch
* **Scatter plots**: https://matplotlib.org/api/_as_gen/matplotlib.pyplot.scatter.html#matplotlib.pyplot.scatter
* **Text**: https://matplotlib.org/api/text_api.html#matplotlib.text.Text

"""
import argparse
import itertools

import matplotlib
import matplotlib.colors
import matplotlib.pyplot
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import scipy
import sklearn.metrics

import matplotlib_venn

import typing
from typing import Collection, Iterable, Mapping, Optional, Sequence, Tuple, Union

# Everything `plot_simple_bar_chart` accepts for its `colors` argument: a
# colormap, a sequence of colors, or a single scalar color (int or str).
BarChartColorOptions = Union[
    matplotlib.colors.Colormap,
    Sequence,
    int,
    str
]
# A value which may be given either as an int or as a string; used in this
# module for font sizes and rotations.
IntOrString = Union[int, str]
# The (figure, axis) pair returned by the plotting helpers.
FigAx = Tuple[matplotlib.figure.Figure, plt.Axes]
# Either a mapping or a sequence.
MapOrSequence = Union[Mapping,Sequence]

import pyllars.utils as utils
import pyllars.validation_utils as validation_utils

import logging
logger = logging.getLogger(__name__)

###
# Constants
###

# Checked against `axis` arguments with `validation_utils.validate_in_set`
VALID_AXIS_VALUES = {'both', 'x', 'y'}
"""Valid `axis` values"""

# Checked against `which` arguments with `validation_utils.validate_in_set`
VALID_WHICH_VALUES = {'major', 'minor', 'both'}
"""Valid `which` values"""

# Used for membership tests of the form `axis in X_AXIS_VALUES`
X_AXIS_VALUES = {'both', 'x'}
"""`axis` choices which affect the X axis"""

# Used for membership tests of the form `axis in Y_AXIS_VALUES`
Y_AXIS_VALUES = {'both', 'y'}
"""`axis` choices which affect the Y axis"""

def _get_fig_ax(ax:Optional[plt.Axes]):
    """ Return the figure and axis to draw on

    Parameters
    ----------
    ax : typing.Optional[matplotlib.axes.Axes]
        An existing axis, or `None` to have a new figure and axis created

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure reported by `ax.get_figure()`, or the new figure

    ax : matplotlib.axes.Axes
        Either the given `ax` or the new axis
    """
    if ax is not None:
        return ax.get_figure(), ax

    # nothing to draw on yet, so let pyplot create both pieces
    return plt.subplots()
###
# Font helpers
###
def set_legend_title_fontsize(
        ax:plt.Axes,
        fontsize:IntOrString) -> None:
    """ Change the font size of the legend title on `ax`

    The legend is read directly from `ax.legend_`, so the axis is expected
    to already have a legend.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis whose legend title is updated

    fontsize : int, or a str recognized by matplotlib
        The new size of the legend title

    Returns
    -------
    None, but the legend title fontsize is updated
    """
    legend_title = ax.legend_.get_title()
    plt.setp(legend_title, fontsize=fontsize)
def set_legend_fontsize(
        ax:plt.Axes,
        fontsize:IntOrString) -> None:
    """ Change the font size of the legend entries on `ax`

    The legend is read directly from `ax.legend_`, so the axis is expected
    to already have a legend.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis whose legend entries are updated

    fontsize : int, or a str recognized by matplotlib
        The new size of the legend text

    Returns
    -------
    None, but the legend text fontsize is updated
    """
    legend_texts = ax.legend_.get_texts()
    plt.setp(legend_texts, fontsize=fontsize)
def set_title_fontsize(
        ax:plt.Axes,
        fontsize:IntOrString) -> None:
    """ Change the font size of the title of `ax`

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis whose title is updated

    fontsize : int, or a str recognized by matplotlib
        The new size of the title font

    Returns
    -------
    None, but the title fontsize is updated
    """
    axis_title = ax.title
    axis_title.set_fontsize(fontsize=fontsize)
def set_label_fontsize(
        ax:plt.Axes,
        fontsize:IntOrString,
        axis:str='both') -> None:
    """ Change the font size of the x and/or y axis label of `ax`

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis whose labels are updated

    fontsize : int, or a str recognized by matplotlib
        The new size of the label font

    axis : str in {`both`, `x`, `y`}
        Which label(s) to update

    Returns
    -------
    None, but the respective label fontsizes are updated
    """
    validation_utils.validate_in_set(axis, VALID_AXIS_VALUES, "axis")

    # collect the matplotlib axis objects selected by `axis`
    selected = []
    if axis in ('both', 'x'):
        selected.append(ax.xaxis)
    if axis in ('both', 'y'):
        selected.append(ax.yaxis)

    for mpl_axis in selected:
        mpl_axis.label.set_fontsize(fontsize)
def set_ticklabels_fontsize(
        ax:plt.Axes,
        fontsize:IntOrString,
        axis:str='both',
        which:str='major'):
    """ Change the font size of the tick labels of `ax`

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis whose tick labels are updated

    fontsize : int, or a str recognized by matplotlib
        The new size of the tick labels

    {axis,which} : str
        Values passed to :meth:`matplotlib.axes.Axes.tick_params`. `axis`
        must be in `VALID_AXIS_VALUES` and `which` must be in
        `VALID_WHICH_VALUES`. Please see the matplotlib documentation for
        more details.

    Returns
    -------
    None, but the ticklabel fontsizes are updated
    """
    checks = (
        (axis, VALID_AXIS_VALUES, "axis"),
        (which, VALID_WHICH_VALUES, "which"),
    )
    for value, allowed, name in checks:
        validation_utils.validate_in_set(value, allowed, name)

    ax.tick_params(axis=axis, which=which, labelsize=fontsize)
def set_ticklabel_rotation(
        ax:plt.Axes,
        rotation:IntOrString,
        axis:str='x',
        which:str='both'):
    """ Rotate the tick labels of `ax`

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis whose tick labels are rotated

    rotation : int, or a string matplotlib recognizes
        The rotation of the labels

    axis : str in {`both`, `x`, `y`}
        The axis (or axes) whose tick labels are rotated

    which : str in {`major`, `minor`, `both`}
        Whether the major tick labels, the minor tick labels, or both are
        rotated

    Returns
    -------
    None, but the ticklabels are rotated
    """
    validation_utils.validate_in_set(axis, VALID_AXIS_VALUES, "axis")
    validation_utils.validate_in_set(which, VALID_WHICH_VALUES, "which")

    include_major = which in ('major', 'both')
    include_minor = which in ('minor', 'both')

    mpl_axes = []
    if axis in ('x', 'both'):
        mpl_axes.append(ax.xaxis)
    if axis in ('y', 'both'):
        mpl_axes.append(ax.yaxis)

    # the x axis (if selected) is always handled before the y axis
    for mpl_axis in mpl_axes:
        ticklabels = []
        if include_major:
            ticklabels.extend(mpl_axis.get_majorticklabels())
        if include_minor:
            ticklabels.extend(mpl_axis.get_minorticklabels())

        plt.setp(ticklabels, rotation=rotation)
###
# Axes helpers
###
def center_splines(ax:plt.Axes) -> None:
    """ Move the left and bottom spines of `ax` to the zero position

    The right and top spines are made invisible by setting their color to
    `none`. This is useful for things like scatter plots where (0,0) should
    be in the center of the plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis

    Returns
    -------
    None, but the spines, tick positions and label positions are updated
    """
    for side in ('left', 'bottom'):
        ax.spines[side].set_position('zero')

    for side in ('right', 'top'):
        ax.spines[side].set_color('none')

    # ticks are only drawn on the two spines which remain visible
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    ax.xaxis.set_label_coords(0.5, 0)
    ax.yaxis.set_label_coords(-0.05, 0.5)
def hide_tick_labels(
        ax:plt.Axes,
        axis:str='both') -> None:
    """ Hide all of the tick labels on the specified axes

    This simply calls `hide_tick_labels_by_index` without asking it to keep
    any of the labels.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis

    axis : str in {`both`, `x`, `y`}
        Axis of the tick labels to hide

    Returns
    -------
    None, but the tick labels of the axis are hidden, as specified
    """
    hide_tick_labels_by_index(ax=ax, axis=axis)
def hide_first_y_tick_label(ax:plt.Axes) -> None:
    """ Make the label of the first major tick on the y-axis invisible

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis

    Returns
    -------
    None, but the tick label is hidden
    """
    first_tick = ax.yaxis.get_major_ticks()[0]
    first_tick.label1.set_visible(False)
def hide_tick_labels_by_text(
        ax:plt.Axes,
        to_remove_x:Collection=set(),
        to_remove_y:Collection=set()) -> None:
    """ Hide the major tick labels whose text matches the given values

    The indices of the major ticks whose label text is *not* in the
    respective collection are passed to `hide_tick_labels_by_index` as the
    labels to keep.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis

    to_remove_{x,y}: typing.Collection[str]
        The label texts to hide on the respective axis

    Returns
    -------
    None, but the specified tick labels are hidden
    """
    def _indices_to_keep(mpl_axis, to_remove):
        # positions of the major ticks whose text is not marked for removal
        return [
            index
            for index, tick in enumerate(mpl_axis.get_major_ticks())
            if tick.label1.get_text() not in to_remove
        ]

    keep_x = _indices_to_keep(ax.xaxis, to_remove_x)
    keep_y = _indices_to_keep(ax.yaxis, to_remove_y)

    hide_tick_labels_by_index(ax, keep_x=keep_x, keep_y=keep_y)
def hide_tick_labels_by_index(
        ax:plt.Axes,
        keep_x:Collection=set(),
        keep_y:Collection=set(),
        axis:str='both') -> None:
    """ Hide the major tick labels, except those at the given indices

    For each selected axis, every major tick label is first made invisible,
    and then the labels at the "keep" indices are made visible again.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis

    keep_{x,y} : typing.Collection[int]
        The indices of any ticks to keep on the respective axis. The
        numbers are used directly as indices into the list of major ticks.

    axis : str in {`both`, `x`, `y`}
        Axis of the tick labels to hide

    Returns
    -------
    None, but the tick labels of the axis are hidden, as specified
    """
    validation_utils.validate_in_set(axis, VALID_AXIS_VALUES, "axis")

    targets = []
    if axis in X_AXIS_VALUES:
        targets.append((ax.xaxis, keep_x))
    if axis in Y_AXIS_VALUES:
        targets.append((ax.yaxis, keep_y))

    for mpl_axis, keep in targets:
        ticks = mpl_axis.get_major_ticks()

        for tick in ticks:
            tick.label1.set_visible(False)

        for index in keep:
            ticks[index].label1.set_visible(True)
###
# Standard, generic plot helpers
###
def plot_simple_bar_chart(
        bars:Sequence[Sequence[float]],
        ax:Optional[plt.Axes]=None,
        labels:Optional[Sequence[str]]=None,
        colors:BarChartColorOptions=plt.cm.Blues,
        xticklabels:Optional[Union[str,Sequence[str]]]='default',
        xticklabels_rotation:IntOrString='vertical',
        xlabel:Optional[str]=None,
        ylabel:Optional[str]=None,
        spacing:float=0,
        ymin:Optional[float]=None,
        ymax:Optional[float]=None,
        use_log_scale:bool=False,
        hide_first_ytick:bool=True,
        show_legend:bool=False,
        title:Optional[str]=None,
        tick_fontsize:int=12,
        label_fontsize:int=12,
        legend_fontsize:int=12,
        title_fontsize:int=12,
        tick_offset:float=0.5) -> FigAx:
    """ Plot a simple (clustered) bar chart based on the values in `bars`

    Parameters
    -----------
    bars : typing.Sequence[typing.Sequence[float]]
        The heights of the bars. Each element of the "outer" sequence is
        one series; series `i` is drawn with `colors[i]` and `labels[i]`.
        The "inner" sequence gives the height of that series at each
        position on the x axis, so each x position holds a cluster of
        `len(bars)` bars. The number of x positions is taken from the first
        series. As a data science example, the series may correspond to
        different methods, while the x positions correspond to different
        datasets.

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    labels : typing.Optional[typing.Sequence[str]]
        The (legend) label for each series in `bars`. If not given, empty
        labels are used.

    colors : BarChartColorOptions
        The color for each series in `bars`. The options and their
        interpretations are:

        * color map : the color of each series will be taken as
          equi-distant colors sampled from the map. For example, if there
          are three series, then the colors will be: `colors(0.0)`,
          `colors(0.5)`, and `colors(1.0)`.

        * sequence of colors : the color of each series will be taken from
          the respective position in the sequence. Its length must match
          the number of series.

        * scalar (int or str) : all bars will use this color

    xticklabels : typing.Optional[typing.Union[str,typing.Sequence[str]]]
        The tick labels for the clusters on the x axis. The options and
        their interpretations are:

        * None : the ticks and tick labels on the x axis are hidden
        * "default" : the tick labels will be the numeric tick positions
        * sequence of strings : the tick labels will be the respective
          strings

    xticklabels_rotation : typing.Union[str,int]
        The rotation for the `xticklabels`. If a string is given, it should
        be something which matplotlib can interpret as a rotation.

    {x,y}label : typing.Optional[str]
        Labels for the respective axes

    spacing : float
        Controls the empty space around each cluster: all bars of one
        cluster together span `1 - 2*spacing` units on the x axis.

    y{min,max} : typing.Optional[float]
        The min and max for the y axis. If not given, the default min is 0
        (or 1 if a logarithmic scale is used, see option below), and the
        default max is 2 times the height of the highest bar.

    use_log_scale : bool
        Whether to use a normal or logarithmic scale for the y axis

    hide_first_ytick : bool
        Whether to hide the label of the first major tick on the y axis.
        Typically, the first tick mark is either 0 or 1 (depending on the
        scale of the y axis). This can be distracting to see, so the
        default is to hide it.

    show_legend : bool
        Whether to show the legend

    title : typing.Optional[str]
        A title for the axis

    {tick,label,legend,title}_fontsize : int
        The font size for the respective elements

    tick_offset : float
        The offset of the tick mark and label for each cluster on the x
        axis. The ticks are placed at `position + tick_offset - spacing`.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the bars were plotted

    ax : matplotlib.axes.Axes
        The axis on which the bars were plotted

    Raises
    ------
    ValueError
        If `colors` is a sequence whose length differs from `len(bars)`
    """
    fig, ax = _get_fig_ax(ax)

    num_series = len(bars)

    # first, handle the bars

    # TODO: check that the bar arrays are all the same length
    xticks = np.arange(len(bars[0]))

    # the bars of one cluster share the space left over after the spacing
    width = (1 - 2*spacing) / num_series

    if isinstance(colors, matplotlib.colors.Colormap):
        # then use "num_series" equi-distant colors
        colors = [colors(c) for c in np.linspace(0, 1, num_series)]

    elif validation_utils.validate_is_sequence(colors, raise_on_invalid=False):
        # make sure this is the correct size
        if len(colors) != num_series:
            msg = ("The number of colors ({}) and the number of bars({}) does "
                "not match.".format(len(colors), num_series))
            raise ValueError(msg)
    else:
        # we assume color is a scalar, and we will use the same color
        # for all bars
        colors = [colors] * num_series

    if labels is None:
        labels = np.full(num_series, "", dtype=object)

    for i, bar in enumerate(bars):
        # shift each series to the right so the bars sit side-by-side
        xpos = xticks + i*width

        # a short series only occupies the first positions
        if len(bar) < len(xpos):
            xpos = xpos[:len(bar)]

        ax.bar(xpos, bar, width=width, color=colors[i], label=labels[i])

    # now the x-axis
    if isinstance(xticklabels, str) and (xticklabels == "default"):
        xticklabels = xticks

    tick_offset = tick_offset - spacing

    if xticklabels is not None:
        ax.set_xticks(xticks+tick_offset)
        ax.set_xticklabels(
            xticklabels,
            fontsize=tick_fontsize,
            rotation=xticklabels_rotation
        )
    else:
        # tick_params expects booleans here; the string 'off' is truthy, so
        # passing it would leave the ticks and labels visible
        ax.tick_params(
            axis='x',
            which='both',
            bottom=False,
            top=False,
            labelbottom=False
        )

    ax.set_xlim((-width, len(xticks)+width/2))

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=label_fontsize)

    # and the y-axis
    if use_log_scale:
        ax.set_yscale('log')

    if ymin is None:
        # 0 is not a valid lower limit on a logarithmic scale
        ymin = 1 if use_log_scale else 0

    if ymax is None:
        ymax = 2*max(max(x) for x in bars)

    ax.set_ylim((ymin, ymax))

    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=label_fontsize)

    if hide_first_ytick:
        yticks = ax.yaxis.get_major_ticks()
        yticks[0].label1.set_visible(False)

    # and the legend
    if show_legend:
        ax.legend(fontsize=legend_fontsize)

    # and the title
    if title is not None:
        ax.set_title(title, fontsize=title_fontsize)

    return fig, ax
def plot_simple_scatter(
        x:Sequence[float],
        y:Sequence[float],
        ax:Optional[plt.Axes]=None,
        equal_aspect:bool=True,
        set_lim:bool=True,
        show_y_x_line:bool=True,
        xy_line_kwargs:dict={},
        **kwargs)->FigAx:
    """ Draw a scatter plot of `x` against `y` on `ax`

    See the matplotlib documentation for more keyword arguments and details:
    https://matplotlib.org/api/_as_gen/matplotlib.pyplot.scatter.html#matplotlib.pyplot.scatter

    Parameters
    ----------
    {x,y} : typing.Sequence[float]
        The values to plot

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    equal_aspect : bool
        Whether to set the aspect of the axis to `equal`

    set_lim : bool
        Whether to set both axis limits to the overall (min, max) of the
        values in `x` and `y`

    show_y_x_line : bool
        Whether to draw the y=x line between the overall min and max
        values. This will look weird if `set_lim` is False.

    xy_line_kwargs : typing.Mapping
        Keyword arguments for plotting the y=x line, if it is plotted

    **kwargs : <key>=<value> pairs
        Additional keyword arguments to pass to the scatter function. Some
        useful keyword arguments are:

        * `label` : the label for a legend
        * `marker` : https://matplotlib.org/examples/lines_bars_and_markers/marker_reference.html

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the scatter points were plotted

    ax : matplotlib.axes.Axes
        The axis on which the scatter points were plotted
    """
    fig, ax = _get_fig_ax(ax)

    ax.scatter(x, y, **kwargs)

    # the same (lowest, highest) range is used for both axes and the line
    lowest = min(min(x), min(y))
    highest = max(max(x), max(y))
    lim = (lowest, highest)

    if set_lim:
        ax.set_xlim(lim)
        ax.set_ylim(lim)

    if show_y_x_line:
        ax.plot(lim, lim, **xy_line_kwargs)

    if equal_aspect:
        ax.set_aspect('equal')

    return fig, ax
def plot_stacked_bar_graph(
        ax,
        data,
        colors=plt.cm.Blues,
        x_tick_labels=None,
        stack_labels=None,
        y_ticks=None,
        y_tick_labels=None,
        hide_first_ytick=True,
        edge_colors=None,
        showFirst=-1,
        scale=False,
        widths=None,
        heights=None,
        y_title=None,
        x_title=None,
        gap=0.,
        end_gaps=False,
        show_legend=True,
        legend_loc="best",
        legend_bbox_to_anchor=None,
        legend_ncol=-1,
        log=False,
        font_size=8,
        label_font_size=12,
        legend_font_size=8):
    """ Create a stacked bar plot with the given characteristics.

    This code is adapted from code by Michael Imelfort.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis to plot onto

    data : 2-d array-like
        The data to plot. Each row is one bar and each column is one level
        (stack) within the bars.

    colors : matplotlib.colors.Colormap or sequence of colors
        The color for each level. If a colormap is given, level `i` uses
        `colors(i/levels)`.

    x_tick_labels : optional sequence of str
        The label for each bar. If not given, the x axis has no ticks.

    stack_labels : optional sequence of str
        The legend text for each level

    y_ticks : optional sequence of float
        The positions of the y ticks. If not given, the ticks and tick
        labels of the y axis are hidden.

    y_tick_labels : optional sequence of str
        The labels for `y_ticks`. Only used if `y_ticks` is given.

    hide_first_ytick : bool
        Whether to hide the label of the first major y tick. Only used if
        `y_ticks` is given.

    edge_colors : optional color or sequence of colors
        The edge color for each level. By default, the edges use the same
        colors as the bars. A single value is used for all levels.

    showFirst : int
        If not -1, only the first `showFirst` bars are plotted

    scale : bool
        Whether to scale all bars to the same height (1)

    widths : optional float or sequence of float
        The width of each bar. A single value is used for all bars. By
        default, all bars have width 1.

    heights : optional sequence of float
        The total height of each bar. This is ignored (with a warning) if
        `scale` is True.

    {x,y}_title : optional str
        The labels for the respective axes

    gap : float
        The gap between bars; it is subtracted from each width

    end_gaps : bool
        Whether to allow gaps at the ends of the bar chart (only relevant
        if `gap` is not 0)

    show_legend : bool
        Whether to show the legend

    legend_loc, legend_bbox_to_anchor
        Passed to `ax.legend` as `loc` and `bbox_to_anchor`

    legend_ncol : int
        The number of columns in the legend. If less than 1, then one
        column per level is used.

    log : bool
        Whether to use a log scale; passed to `ax.bar`

    font_size : int
        The font size for the tick labels

    label_font_size : int
        The font size for the axis labels

    legend_font_size : int
        The font size for the legend

    Returns
    -------
    bars : list
        The value returned by `ax.bar` for each level, from the bottom
        level to the top level

    Raises
    ------
    ValueError
        If `edge_colors` is a sequence whose length differs from the number
        of levels
    """
    # arrange a float copy of the data as (levels, bars); the float dtype is
    # needed for the in-place divisions when scaling
    if showFirst != -1:
        showFirst = np.min([showFirst, np.shape(data)[0]])
        data_copy = np.copy(data[:showFirst]).transpose().astype('float')

        if heights is not None:
            heights = heights[:showFirst]
        if widths is not None:
            widths = widths[:showFirst]
    else:
        data_copy = np.copy(data).transpose().astype('float')

    # determine the number of bars and corresponding levels from the shape
    # of the data
    data_shape = np.shape(data_copy)
    levels = data_shape[0]
    num_bars = data_shape[1]

    if widths is None:
        widths = np.array([1] * num_bars)
        x = np.arange(num_bars)
    else:
        if not validation_utils.validate_is_sequence(widths, raise_on_invalid=False):
            widths = np.full(num_bars, widths)

        logger.debug("widths: %s", widths)

        x = [0]
        for i in range(1, len(widths)):
            x.append(x[i-1] + widths[i])

    # stack the data --
    # replace the value in each level by the cumulative sum of all
    # preceding levels
    data_stack = np.cumsum(data_copy, axis=0)

    # scale the data if needed
    normalize = scale or (heights is not None)

    if scale and (heights is not None):
        logger.warning("Setting scale and heights does not make sense. "
            "Ignoring heights.")
        heights = None

    if normalize:
        # copy the totals so that normalizing data_stack does not change the
        # divisor while it is being used
        totals = data_stack[levels-1].copy()
        data_copy /= totals
        data_stack /= totals

    if heights is not None:
        for i in range(num_bars):
            data_copy[:,i] *= heights[i]
            data_stack[:,i] *= heights[i]

    # plot

    # if we were given a color map, convert it to a list of colors
    if isinstance(colors, matplotlib.colors.Colormap):
        colors = [colors(i/levels) for i in range(levels)]

    if edge_colors is None:
        edge_colors = colors
    elif not validation_utils.validate_is_sequence(edge_colors, raise_on_invalid=False):
        edge_colors = np.full(levels, edge_colors, dtype=object)
    elif len(edge_colors) != levels:
        msg = "The number of edge_colors must match the number of stacks."
        raise ValueError(msg)

    # take care of gaps
    gapd_widths = [w - gap for w in widths]

    if stack_labels is None:
        stack_labels = np.full(levels, '', dtype=object)

    # bars
    bars = []
    for i in range(levels):
        # the first level sits on the axis (bottom=None is the default of
        # ax.bar); every other level sits on top of the stack below it
        bottom = data_stack[i-1] if i > 0 else None

        bar = ax.bar(x,
            data_copy[i],
            bottom=bottom,
            color=colors[i],
            edgecolor=edge_colors[i],
            width=gapd_widths,
            linewidth=0.5,
            align='center',
            label=stack_labels[i],
            log=log
        )
        bars.append(bar)

    # borders
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)

    # make ticks if necessary
    if y_ticks is not None:
        ax.set_yticks(y_ticks)

        if y_tick_labels is not None:
            ax.set_yticklabels(y_tick_labels, fontsize=font_size)

        if hide_first_ytick:
            yticks = ax.yaxis.get_major_ticks()
            yticks[0].label1.set_visible(False)
    else:
        # tick_params expects booleans here; the string 'off' is truthy, so
        # passing it would leave the ticks and labels visible
        ax.tick_params(
            axis='y',
            which='both',
            left=False,
            right=False,
            labelright=False,
            labelleft=False)

    if x_tick_labels is not None:
        ax.tick_params(axis='x', which='both', labelsize=font_size, direction="out")
        ax.xaxis.tick_bottom()
        ax.set_xticks(x)
        ax.set_xticklabels(x_tick_labels, rotation='vertical')
    else:
        ax.set_xticks([])
        ax.set_xticklabels([])

    # limits
    if end_gaps:
        ax.set_xlim(-1.*widths[0]/2. - gap/2., np.sum(widths)-widths[0]/2. + gap/2.)
    else:
        ax.set_xlim(-1.*widths[0]/2. + gap/2., np.sum(widths)-widths[0]/2. - gap/2.)

    # labels
    if x_title is not None:
        ax.set_xlabel(x_title, fontsize=label_font_size)
    if y_title is not None:
        ax.set_ylabel(y_title, fontsize=label_font_size)

    # legend
    if show_legend:
        if legend_ncol < 1:
            legend_ncol = len(stack_labels)

        ax.legend(loc=legend_loc, bbox_to_anchor=legend_bbox_to_anchor,
            ncol=legend_ncol, fontsize=legend_font_size)

    return bars
def plot_sorted_values(
        values:Sequence[float],
        ymin:Optional[float]=None,
        ymax:Optional[float]=None,
        ax:Optional[plt.Axes]=None,
        scale_x:bool=False,
        **kwargs) -> FigAx:
    """ Sort `values` and plot them

    Parameters
    ----------
    values : typing.Sequence[float]
        The values to plot

    y{min,max} : typing.Optional[float]
        The min and max values for the y-axis. If not given, then these
        default to the minimum and maximum values in `values`.

    ax : typing.Optional[matplotlib.axes.Axes]
        An axis for plotting. If this is not given, then a figure and axis
        will be created.

    scale_x : bool
        If True, then the `x` values will be equally-spaced between 0 and
        1, and the x-axis limits are (0, 1). Otherwise, the `x` values will
        be the values 0 to len(values) - 1, and the x-axis limits are
        (0, len(values)).

    **kwargs : <key>=<value> pairs
        Additional keyword arguments to pass to the plot function. Some
        useful keyword arguments are:

        * `label` : the label for a legend
        * `lw` : the line width
        * `ls` : https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html
        * `marker` : https://matplotlib.org/examples/lines_bars_and_markers/marker_reference.html

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure associated with `ax`, or a new Figure

    ax : matplotlib.axes.Axes
        Either `ax` or a new Axis
    """
    fig, ax = _get_fig_ax(ax)

    y = np.sort(values)

    if scale_x:
        x = np.linspace(0, 1, len(y))
        xmax = 1
    else:
        x = np.arange(len(y))
        xmax = len(y)

    ax.plot(x, y, **kwargs)

    if len(y) == 0:
        # there is nothing to derive the axis limits from, so leave them
        # untouched rather than failing on y[0]
        return fig, ax

    if ymin is None:
        ymin = y[0]

    if ymax is None:
        ymax = y[-1]

    ax.set_ylim((ymin, ymax))

    # the x limits must match the scale chosen for the x values above
    ax.set_xlim((0, xmax))

    return fig, ax
###
# High-level, ML and statistics plotting helpers
###
def plot_binary_prediction_scores(
        y_scores:Sequence[float],
        y_true:Sequence[int],
        positive_label:int=1,
        positive_line_color='g',
        negative_line_color='r',
        line_kwargs:typing.Mapping={},
        positive_line_kwargs:typing.Mapping={},
        negative_line_kwargs:typing.Mapping={},
        title:Optional[str]=None,
        ylabel:Optional[str]="Score",
        xlabel:Optional[str]="Instance",
        title_font_size:int=20,
        label_font_size:int=15,
        ticklabels_font_size:int=15,
        ax:Optional[plt.Axes]=None) -> FigAx:
    """ Plot separate lines for the scores of the positives and negatives

    Parameters
    ----------
    y_scores : typing.Sequence[float]
        The predicted scores of the positive class. For example, this may
        be found using something like: `y_scores = y_proba_pred[:,1]` for
        probabilistic predictions from most `sklearn` classifiers.

    y_true : typing.Sequence[int]
        The ground truth labels

    positive_label : int
        The value for the "positive" class

    {positive,negative}_line_color : color
        Values to use for the color of the respective lines. These can be
        anything which `matplotlib.plot` can interpret. These values have
        precedence over the other `kwargs` parameters.

    line_kwargs : typing.Mapping
        Other keyword arguments passed through to `plot` for both lines.

    {positive,negative}_line_kwargs : typing.Mapping
        Other keyword arguments passed through to `plot` for only the
        respective line. These values have precedence over `line_kwargs`.

    title : typing.Optional[str]
        If given, the title of the axis is set to this value

    {y,x}label : typing.Optional[str]
        Text for the respective labels

    {title,label,ticklabels}_font_size : int
        The font sizes for the respective elements.

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the scores lines were plotted

    ax : matplotlib.axes.Axes
        The axis on which the score lines were plotted
    """
    fig, ax = _get_fig_ax(ax)

    # boolean masking only works on arrays; comparing a plain list to the
    # label would give a single bool rather than a mask
    y_scores = np.asarray(y_scores)
    y_true = np.asarray(y_true)

    # pull out the positives
    m_positives = (y_true == positive_label)
    y_scores_positive = y_scores[m_positives]
    y_scores_negative = y_scores[~m_positives]

    # the explicit line colors win over anything given in the kwargs
    positives_kwargs = {**line_kwargs, **positive_line_kwargs}
    positives_kwargs['color'] = positive_line_color
    plot_sorted_values(y_scores_positive, ax=ax, **positives_kwargs)

    negatives_kwargs = {**line_kwargs, **negative_line_kwargs}
    negatives_kwargs['color'] = negative_line_color
    plot_sorted_values(y_scores_negative, ax=ax, **negatives_kwargs)

    # plot_sorted_values sets the axis limits to fit only the line it just
    # drew, so reset them to cover both lines
    if len(y_scores) > 0:
        num_instances = max(len(y_scores_positive), len(y_scores_negative))
        ax.set_xlim((0, num_instances))
        ax.set_ylim((np.min(y_scores), np.max(y_scores)))

    if title is not None:
        ax.set_title(title, fontsize=title_font_size)

    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=label_font_size)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=label_font_size)

    set_ticklabels_fontsize(ax, ticklabels_font_size)

    return fig, ax
def plot_confusion_matrix(
        confusion_matrix:np.ndarray,
        ax:Optional[plt.Axes]=None,
        show_cell_labels:bool=True,
        show_colorbar:bool=True,
        title:Optional[str]="Confusion matrix",
        cmap:matplotlib.colors.Colormap=plt.cm.Blues,
        true_tick_labels:Optional[Sequence[str]]=None,
        predicted_tick_labels:Optional[Sequence[str]]=None,
        ylabel:Optional[str]="True labels",
        xlabel:Optional[str]="Predicted labels",
        title_font_size:int=20,
        label_font_size:int=15,
        true_tick_rotation:Optional[IntOrString]=None,
        predicted_tick_rotation:Optional[IntOrString]=None,
        out:Optional[str]=None) -> FigAx:
    """ Plot the given confusion matrix

    Parameters
    -----------
    confusion_matrix : numpy.ndarray
        A 2-d array, presumably from :func:`sklearn.metrics.confusion_matrix`
        or something similar. The rows (Y axis) are the "true" classes
        while the columns (X axis) are the "predicted" classes.

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    show_cell_labels : bool
        Whether to show the values within each cell

    show_colorbar : bool
        Whether to show a color bar

    title : typing.Optional[str]
        If given, the title of the axis is set to this value

    cmap : matplotlib.colors.Colormap
        A colormap to determine the cell colors. If `None` is given, then
        `plt.cm.Blues` is used.

    {true,predicted}_tick_labels : typing.Optional[typing.Sequence[str]]
        Text for the Y (true) and X (predicted) axis, respectively. If not
        given, the numeric tick positions are used.

    {y,x}label : typing.Optional[str]
        Text for the respective labels

    {title,label}_font_size : int
        The font sizes for the respective elements. The class labels (on
        the tick marks) and the cell labels use the `label_font_size`.

    {true,predicted}_tick_rotation : typing.Optional[IntOrString]
        The rotation arguments for the respective tick labels. Please see
        the matplotlib text documentation
        (https://matplotlib.org/api/text_api.html#matplotlib.text.Text) for
        more details.

    out : typing.Optional[str]
        If given, the figure will be saved to this file.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the confusion matrix was plotted

    ax : matplotlib.axes.Axes
        The axis on which the confusion matrix was plotted
    """
    fig, ax = _get_fig_ax(ax)

    # callers may explicitly pass None to ask for the default colormap
    if cmap is None:
        cmap = plt.cm.Blues

    mappable = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)

    if show_colorbar:
        fig.colorbar(mappable)

    ax.grid(False)

    num_true, num_predicted = confusion_matrix.shape[0], confusion_matrix.shape[1]

    # the y axis holds the true classes
    true_tick_marks = np.arange(num_true)
    ax.set_yticks(true_tick_marks)

    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=label_font_size)

    if true_tick_labels is None:
        true_tick_labels = list(true_tick_marks)

    ax.set_yticklabels(
        true_tick_labels,
        fontsize=label_font_size,
        rotation=true_tick_rotation
    )

    # the x axis holds the predicted classes
    predicted_tick_marks = np.arange(num_predicted)
    ax.set_xticks(predicted_tick_marks)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=label_font_size)

    if predicted_tick_labels is None:
        predicted_tick_labels = list(predicted_tick_marks)

    ax.set_xticklabels(
        predicted_tick_labels,
        fontsize=label_font_size,
        rotation=predicted_tick_rotation
    )

    if show_cell_labels:
        # the choice of color is based on this SO thread:
        # https://stackoverflow.com/questions/2509443
        color_threshold = 125

        for i, j in itertools.product(range(num_true), range(num_predicted)):
            val = confusion_matrix[i,j]
            red, green, blue = cmap(mappable.norm(val))[:3]

            # perceived brightness of the cell; see the SO thread
            # mentioned above
            color_intensity = (
                (255*red * 299) +
                (255*green * 587) +
                (255*blue * 114)
            ) / 1000

            # dark text on bright cells, and bright text on dark cells
            font_color = "black" if color_intensity > color_threshold else "white"

            ax.text(j, i, val, ha='center', va='center', color=font_color,
                size=label_font_size)

    if title is not None:
        ax.set_title(title, fontsize=title_font_size)

    fig.tight_layout()

    if out is not None:
        # save via the figure itself; `ax` may belong to a figure which is
        # not pyplot's current figure
        fig.savefig(out, bbox_inches='tight')

    return fig, ax
def plot_mean_roc_curve(
        tprs:Sequence[Sequence[float]],
        fprs:Sequence[Sequence[float]],
        aucs:Optional[Sequence[float]]=None,
        label_note:Optional[str]=None,
        line_style:Mapping={'c':'b', 'lw':2, 'alpha':0.8},
        fill_style:Mapping={'color': 'grey', 'alpha':0.2},
        show_xy_line:bool=True,
        xy_line_kwargs:Mapping={'color': 'r', 'ls': '--', 'lw': 2},
        ax:Optional[plt.Axes]=None,
        title:Optional[str]=None,
        xlabel:Optional[str]="False positive rate",
        ylabel:Optional[str]="True positive rate",
        title_font_size:int=25,
        label_font_size:int=20,
        ticklabels_font_size:int=20) -> FigAx:
    """ Plot the mean plus/minus the standard deviation of the given ROC curves

    Parameters
    ----------
    tprs : typing.Sequence[typing.Sequence[float]]
        The true positive rate at each threshold, for each curve

    fprs : typing.Sequence[typing.Sequence[float]]
        The false positive rate at each threshold, for each curve

    aucs : typing.Optional[typing.Sequence[float]]
        The calculated area under each ROC curve. Only the standard
        deviation of these values is used (in the legend label). If not
        given, the areas are calculated from `fprs` and `tprs`.

    label_note : typing.Optional[str]
        A prefix for the label in the legend for this line.

    {line,fill}_style : typing.Mapping
        Keyword arguments for plotting the line and `fill_between`,
        respectively. Please see the mpl docs for more details.

    show_xy_line : bool
        Whether to draw the y=x line

    xy_line_kwargs : typing.Mapping
        Keyword arguments for plotting the x=y line.

    title : typing.Optional[str]
        If given, the title of the axis is set to this value

    {x,y}label : typing.Optional[str]
        Text for the respective labels

    {title,label,ticklabels}_font_size : int
        The font sizes for the respective elements

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the ROC curves were plotted

    ax : matplotlib.axes.Axes
        The axis on which the ROC curves were plotted
    """
    fig, ax = _get_fig_ax(ax)

    # interpolate across the different curves so we have the same points
    mean_fpr = np.linspace(0, 1, 100)
    interp_tprs = []
    for tpr, fpr in zip(tprs, fprs):
        interp_tpr = np.interp(mean_fpr, fpr, tpr)

        # every curve starts at the origin
        interp_tpr[0] = 0.0
        interp_tprs.append(interp_tpr)

    mean_tpr = np.mean(interp_tprs, axis=0)

    # and the mean curve ends at (1, 1)
    mean_tpr[-1] = 1.0

    mean_auc = sklearn.metrics.auc(mean_fpr, mean_tpr)

    if aucs is None:
        # np.std cannot handle None, so calculate the area of each curve
        aucs = [sklearn.metrics.auc(fpr, tpr) for tpr, fpr in zip(tprs, fprs)]

    std_auc = np.std(aucs)

    label = "AUC: {:.2f} $\\pm$ {:.2f}".format(mean_auc, std_auc)
    if label_note is not None:
        label = label_note + label

    ax.plot(mean_fpr, mean_tpr, label=label, **line_style)

    # the band of one standard deviation around the mean, clipped to [0, 1]
    std_tpr = np.std(interp_tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(mean_fpr, tprs_lower, tprs_upper, **fill_style)

    if show_xy_line:
        ax.plot([0,1], [0,1], label='Luck', **xy_line_kwargs)

    ax.set_aspect('equal')
    ax.set_xlim((-0.05, 1.05))
    ax.set_ylim((-0.05, 1.05))

    if title is not None and len(title) > 0:
        ax.set_title(title, fontsize=title_font_size)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=label_font_size)

    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=label_font_size)

    set_ticklabels_fontsize(ax, ticklabels_font_size)

    return fig, ax
def plot_roc_curve(
        tpr:Sequence[Sequence[float]],
        fpr:Sequence[Sequence[float]],
        auc:Optional[Sequence[float]]=None,
        show_points:bool=True,
        ax:Optional[plt.Axes]=None,
        method_names:Optional[Sequence[str]]=None,
        out:Optional[str]=None,
        line_colors:Optional[Sequence]=None,
        point_colors:Optional[Sequence]=None,
        alphas:Optional[Sequence[float]]=None,
        line_kwargs:Optional[Mapping]=None,
        point_kwargs:Optional[Mapping]=None,
        title:Optional[str]="Receiver operating characteristic curves",
        xlabel:Optional[str]="False positive rate",
        ylabel:Optional[str]="True positive rate",
        title_font_size:int=20,
        label_font_size:int=15,
        ticklabels_font_size:int=15) -> FigAx:
    """ Plot the ROC curve for each of the given `fpr` and `tpr` sequences

    One curve is drawn per entry of `tpr`; the i-th entries of `fpr`,
    `auc`, `method_names`, `line_colors`, `point_colors` and `alphas` all
    describe the i-th curve. A dashed red diagonal labeled "Luck" is
    always added.

    Parameters
    ----------
    tpr : typing.Sequence[typing.Sequence[float]]
        For each method, the true positive rate at each threshold

    fpr : typing.Sequence[typing.Sequence[float]]
        For each method, the false positive rate at each threshold

    auc : typing.Optional[typing.Sequence[float]]
        The calculated area under the ROC curve of each method. If given,
        " AUC: <value>" (two decimal places) is appended to the legend
        label of the respective curve.

    show_points : bool
        Whether to plot points at each threshold. If so, the legend entry
        is attached to the points; otherwise, it is attached to the line.

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    method_names : typing.Optional[typing.Sequence[str]]
        The name of each method

    out : typing.Optional[str]
        If given, the plot will be saved to this file.

    line_colors : typing.Optional[typing.Sequence[color]]
        The color of each ROC line. By default, all lines are black.

    point_colors : typing.Optional[typing.Sequence[color]]
        The color of the points on each ROC line. By default, all points
        are black.

    alphas : typing.Optional[typing.Sequence[float]]
        An alpha value for each method. By default, all are 1.0.

    {line,point}_kwargs : typing.Optional[typing.Mapping]
        Additional keyword arguments for the respective elements. `None`
        is treated as "no additional keyword arguments".

    title : typing.Optional[str]
        If given (and not empty), the title of the axis is set to this value

    {x,y}label : typing.Optional[str]
        Text for the respective labels

    {title,label,ticklabels}_font_size : int
        The font sizes for the respective elements

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the ROC curves were plotted

    ax : matplotlib.axes.Axes
        The axis on which the ROC curves were plotted

    Raises
    ------
    ValueError
        If `alphas`, `line_colors` or `point_colors` is given but does not
        have one entry per method (that is, per entry of `tpr`)
    """
    fig, ax = _get_fig_ax(ax)

    num_methods = len(tpr)

    # `None` cannot be unpacked with `**`, so fall back to empty mappings
    if line_kwargs is None:
        line_kwargs = {}

    if point_kwargs is None:
        point_kwargs = {}

    if alphas is None:
        alphas = [1.0] * num_methods
    elif len(alphas) != num_methods:
        msg = "The ROC curve must have the same number of alpha values as methods"
        raise ValueError(msg)

    if line_colors is None:
        line_colors = ['k'] * num_methods
    elif len(line_colors) != num_methods:
        msg = "The ROC curve must have the same number of line colors as methods"
        raise ValueError(msg)

    if point_colors is None:
        point_colors = ['k'] * num_methods
    elif len(point_colors) != num_methods:
        msg = "The ROC curve must have the same number of point colors as methods"
        raise ValueError(msg)

    for i in range(num_methods):
        # build the legend label for this method
        label = ""
        if method_names is not None:
            label += str(method_names[i])

        if auc is not None:
            label += " "
            label += "AUC: {:.2f}".format(auc[i])

        if show_points:
            # draw the segments connecting consecutive thresholds; these
            # are unlabeled so that only the points appear in the legend
            for j in range(1, len(fpr[i])):
                points_y = [tpr[i][j-1], tpr[i][j]]
                points_x = [fpr[i][j-1], fpr[i][j]]

                ax.plot(
                    points_x,
                    points_y,
                    color=line_colors[i],
                    zorder=1,
                    alpha=alphas[i],
                    **line_kwargs
                )

            # the points sit on top of the lines (higher zorder)
            ax.scatter(
                fpr[i],
                tpr[i],
                label=label,
                c=point_colors[i],
                alpha=alphas[i],
                zorder=2,
                **point_kwargs
            )
        else:
            ax.plot(
                fpr[i],
                tpr[i],
                alpha=alphas[i],
                c=line_colors[i],
                label=label,
                **line_kwargs
            )

    # the diagonal corresponding to random guessing
    ax.plot([0,1], [0,1], label='Luck', color='r', ls='--', lw=2)

    ax.set_aspect('equal')
    ax.set_xlim((-0.05, 1.05))
    ax.set_ylim((-0.05, 1.05))

    if title is not None and len(title) > 0:
        ax.set_title(title, fontsize=title_font_size)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=label_font_size)

    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=label_font_size)

    set_ticklabels_fontsize(ax, ticklabels_font_size)

    if out is not None:
        fig.savefig(out, bbox_inches='tight')

    return fig, ax
def plot_trend_line(
        x:Sequence[float],
        intercept:float,
        slope:float,
        power:float,
        ax:Optional[plt.Axes]=None,
        **kwargs) -> FigAx:
    """ Plot the quadratic trend line described by the given coefficients

    The line drawn is `power * x**2 + slope * x + intercept`, evaluated
    at the sorted values of `x`.

    Parameters
    ----------
    x : typing.Sequence[float]
        The points at which the trend line is evaluated. They are sorted
        before plotting, so the order given does not matter.

    {intercept,slope,power} : float
        The constant, linear and quadratic coefficients of the trend
        line, respectively. Presumably, these come from
        :func:`pyllars.stats_utils.fit_with_least_squares` or something
        similar.

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. If not given, then one will be created.

    **kwargs : <key>=<value> pairs
        Keyword arguments forwarded to `ax.plot` (color, etc.). Please
        consult the matplotlib documentation for more details:
        https://matplotlib.org/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the trend line was plotted

    ax : matplotlib.axes.Axes
        The axis on which the trend line was plotted
    """
    fig, ax = _get_fig_ax(ax)

    sorted_x = np.sort(x)

    # evaluate the polynomial term by term
    quadratic_term = power * sorted_x ** 2
    linear_term = slope * sorted_x
    trend = quadratic_term + linear_term + intercept

    ax.plot(sorted_x, trend, **kwargs)

    return fig, ax
def plot_venn_diagram(
        sets:MapOrSequence,
        ax:Optional[plt.Axes]=None,
        set_labels:Optional[Sequence[str]]=None,
        weighted:bool=False,
        use_sci_notation:bool=False,
        sci_notation_limit:float=999,
        labels_fontsize:int=14,
        counts_fontsize:int=12) -> matplotlib_venn._common.VennDiagram:
    """ Wrap the matplotlib_venn package.

    Please consult the package documentation for more details:
    https://github.com/konstantint/matplotlib-venn

    **N.B.** Unlike most of the other high-level plotting helpers,
    this function returns the venn diagram object rather than the
    figure and axis objects.

    Parameters
    ----------
    sets : typing.Union[typing.Mapping,typing.Sequence]
        If a dictionary, it must follow the conventions of
        `matplotlib_venn`. In that case, the number of sets is taken from
        the length of one of its keys (for example, "10" implies two sets
        and "100" implies three), regardless of how many entries the
        dictionary has.

        If a sequence is given, then it must be of length two or three.

        The type of venn diagram will be based on the number of sets.

    ax : typing.Optional[matplotlib.axes.Axes]
        The axis. It is passed through to `matplotlib_venn` as given.

    set_labels : typing.Optional[typing.Sequence[str]]
        The label for each set. The order of the labels must match the
        order of the sets.

    weighted : bool
        Whether the diagram is weighted (in which the size of the circles
        in the venn diagram are based on the number of elements) or
        unweighted (in which all circles are the same size)

    use_sci_notation : bool
        Whether to convert numbers to scientific notation

    sci_notation_limit : float
        The maximum number to show before switching to scientific notation

    {labels,counts}_fontsize : int
        The respective font sizes

    Returns
    -------
    venn_diagram : matplotlib_venn._common.VennDiagram
        The venn diagram

    Raises
    ------
    ValueError
        If the number of sets implied by `sets` is not two or three
        (including when `sets` is an empty dictionary)
    """
    msg = "Only two or three sets are supported"

    if isinstance(sets, dict):
        if len(sets) == 0:
            raise ValueError(msg)

        # for dictionaries, each key encodes one membership flag per set,
        # so the key length (not the number of entries) gives the number
        # of sets
        num_sets = len(next(iter(sets)))
    else:
        num_sets = len(sets)

    if num_sets == 2:
        if weighted:
            venn = matplotlib_venn.venn2
        else:
            venn = matplotlib_venn.venn2_unweighted
    elif num_sets == 3:
        if weighted:
            venn = matplotlib_venn.venn3
        else:
            venn = matplotlib_venn.venn3_unweighted
    else:
        raise ValueError(msg)

    v = venn(sets, ax=ax, set_labels=set_labels)

    # `v.set_labels` may be None rather than a list when no set labels
    # were requested, so guard against iterating over None
    for label in (v.set_labels or []):
        if label is not None:
            label.set_fontsize(labels_fontsize)

    for label in v.subset_labels:
        # regions without a label are given as None
        if label is None:
            continue

        label.set_fontsize(counts_fontsize)

        if use_sci_notation:
            val = int(label.get_text())
            if val > sci_notation_limit:
                val = "{:.0E}".format(val)
                label.set_text(val)

    return v
### # Other helpers ###
def add_fontsizes_to_args(
        args:argparse.Namespace,
        legend_title_fontsize:int=12,
        legend_fontsize:int=10,
        title_fontsize:int=20,
        label_fontsize:int=15,
        ticklabels_fontsize:int=10):
    """ Store font size settings on `args`

    Each keyword argument is written to the attribute of `args` with the
    same name, replacing any existing value. The defaults are intended as
    reasonable font sizes.

    Parameters
    ----------
    args : argparse.Namespace
        The namespace to update (in place)

    {legend_title,legend,title,label,ticklabels}_fontsize : int
        The font sizes for the respective elements

    Returns
    -------
    None, but `args` is updated
    """
    fontsizes = (
        ("legend_title_fontsize", legend_title_fontsize),
        ("legend_fontsize", legend_fontsize),
        ("title_fontsize", title_fontsize),
        ("label_fontsize", label_fontsize),
        ("ticklabels_fontsize", ticklabels_fontsize),
    )

    for attribute, fontsize in fontsizes:
        setattr(args, attribute, fontsize)
def draw_rectangle(
        ax:plt.Axes,
        base_x:float,
        base_y:float,
        width:float,
        height:float,
        center_x:bool=False,
        center_y:bool=False,
        **kwargs) -> FigAx:
    """ Add a rectangle patch to `ax`

    By default, (`base_x`, `base_y`) is the anchor corner of the
    rectangle. Each coordinate can instead be treated as the center of
    the rectangle along that dimension.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis on which the rectangle will be drawn

    base_{x,y} : float
        The base x and y coordinates

    {width,height} : float
        The width (change in x) and height (change in y) of the rectangle

    center_{x,y}: bool
        Whether the respective base coordinate is the center of the
        rectangle rather than its edge. If `center_x` is `True`, then
        `width/2` is subtracted from `base_x`; likewise, if `center_y` is
        `True`, then `height/2` is subtracted from `base_y`.

    **kwargs : key=value pairs
        Additional keywords are passed to the patches.Rectangle
        constructor. Please see the matplotlib documentation for more
        details: https://matplotlib.org/api/_as_gen/matplotlib.patches.Rectangle.html

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure on which the rectangle was drawn

    ax : matplotlib.axes.Axes
        The axis on which the rectangle was drawn
    """
    fig, ax = _get_fig_ax(ax)

    # how far to shift the anchor along each dimension
    shift_y = height/2 if center_y else 0
    shift_x = width/2 if center_x else 0

    anchor = (base_x - shift_x, base_y - shift_y)
    rectangle = patches.Rectangle(anchor, width, height, **kwargs)
    ax.add_patch(rectangle)

    return fig, ax
def get_diff_counts(data_np):
    """ Extract the differential counts needed for a stacked bar graph

    Each row of `data_np` holds the counts for one bar, and the counts
    within a row are expected to be in ascending order: the first column
    contains the smallest count, the second column the next-smallest, and
    so on. For example, if the columns represent successive rounds of
    filtering, then the last column would contain the unfiltered count,
    the next-to-last column the count after the first round of filtering,
    etc.

    The result has the same shape as `data_np`. Its first column equals
    the first column of `data_np`, and every other column is the
    difference between that column and the one before it, so that the
    values can be stacked (as with `stacked_bar_graph`).

    Parameters
    ----------
    data_np : 2-dimensional numpy.ndarray
        The counts, with one row per bar

    Returns
    -------
    diff : numpy.ndarray
        The differences between consecutive columns of `data_np`, after
        a leading column of zeros has been added
    """
    num_rows = data_np.shape[0]

    # a leading column of zeros makes the first difference equal to the
    # first (smallest) count itself
    leading_zeros = np.zeros((num_rows, 1))
    padded = np.concatenate((leading_zeros, data_np), axis=1)

    # differences between consecutive columns, i.e., along the last axis
    return np.diff(padded, axis=-1)