# Source code for pyllars.dask_utils

"""
This module contains helpers for using dask: https://dask.pydata.org/en/latest/
"""

import logging
logger = logging.getLogger(__name__)

from typing import Callable, Iterable, List

import argparse
import collections
import dask.distributed
import pandas as pd
import shlex
import tqdm
import typing

def connect(args:argparse.Namespace) -> typing.Tuple[
        dask.distributed.Client, typing.Optional[dask.distributed.LocalCluster]]:
    """ Connect to the dask cluster specified by the arguments in `args`

    Specifically, this function uses args.cluster_location to determine
    whether to start a dask.distributed.LocalCluster (in case
    args.cluster_location is "LOCAL") or to (attempt to) connect to an
    existing cluster (any other value).

    If a local cluster is started, it will use a number of worker processes
    equal to args.num_procs. Each process will use args.num_threads_per_proc
    threads. The scheduler for the local cluster will listen to a random port.

    Parameters
    ----------
    args: argparse.Namespace
        A namespace containing the following fields:

        * cluster_location
        * client_restart
        * num_procs
        * num_threads_per_proc

    Returns
    -------
    client: dask.distributed.Client
        The client for the dask connection

    cluster: dask.distributed.LocalCluster or None
        If a local cluster is started, the reference to the local cluster
        object is returned. Otherwise, None is returned.
    """
    from dask.distributed import Client as DaskClient
    from dask.distributed import LocalCluster as DaskCluster

    client = None
    cluster = None

    if args.cluster_location == "LOCAL":
        msg = "[dask_utils]: starting local dask cluster"
        logger.info(msg)

        cluster = DaskCluster(
            n_workers=args.num_procs,
            processes=True,
            threads_per_worker=args.num_threads_per_proc
        )

        client = DaskClient(cluster)

    else:
        msg = "[dask_utils]: attempting to connect to dask cluster: {}"
        msg = msg.format(args.cluster_location)
        logger.info(msg)
        client = DaskClient(address=args.cluster_location)

    if args.client_restart:
        msg = "[dask_utils]: restarting client"
        logger.info(msg)
        client.restart()

    return client, cluster

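# A minimal usage sketch for `connect`. The namespace fields mirror those
# added by `add_dask_options`; the values shown are illustrative.
#
#     import argparse
#     args = argparse.Namespace(
#         cluster_location="LOCAL",
#         num_procs=2,
#         num_threads_per_proc=1,
#         client_restart=False,
#     )
#     client, cluster = connect(args)
#     # ... submit work via `client` ...
#     client.close()
#     if cluster is not None:
#         cluster.close()
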
def add_dask_options(
        parser:argparse.ArgumentParser,
        num_procs:int=1,
        num_threads_per_proc:int=1,
        cluster_location:str="LOCAL") -> None:
    """ Add options for connecting to and/or controlling a local dask cluster

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The parser to which the options will be added

    num_procs : int
        The default number of processes for a local cluster

    num_threads_per_proc : int
        The default number of threads for each process for a local cluster

    cluster_location : str
        The default location of the cluster

    Returns
    -------
    None : None
        A "dask cluster options" group is added to the parser
    """
    dask_options = parser.add_argument_group("dask cluster options")

    dask_options.add_argument('--cluster-location', help="The address for the "
        "cluster scheduler. This should either be \"LOCAL\" or the address "
        "and port of the scheduler. If \"LOCAL\" is given, then a "
        "dask.distributed.LocalCluster will be started.",
        default=cluster_location)

    dask_options.add_argument('--num-procs', help="The number of processes "
        "to use for a local cluster", type=int, default=num_procs)

    dask_options.add_argument('--num-threads-per-proc', help="The number of "
        "threads to allocate for each process. So the total number of threads "
        "for a local cluster will be (args.num_procs * "
        "args.num_threads_per_proc).", type=int, default=num_threads_per_proc)

    dask_options.add_argument('--client-restart', help="If this flag is "
        "given, then the \"restart\" function will be called on the client "
        "after establishing the connection to the cluster",
        action='store_true')

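# A short sketch of wiring these options into a script's parser; the
# description string is arbitrary.
#
#     import argparse
#     parser = argparse.ArgumentParser(description="a dask-enabled script")
#     add_dask_options(parser, num_procs=4)
#     args = parser.parse_args(["--cluster-location", "LOCAL"])
#     client, cluster = connect(args)
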
def add_dask_values_to_args(
        args:argparse.Namespace,
        num_procs:int=1,
        num_threads_per_proc:int=1,
        cluster_location:str="LOCAL",
        client_restart:bool=False) -> None:
    """ Add the options for a dask cluster to the given argparse namespace

    This function is mostly intended as a helper for use in ipython notebooks.

    Parameters
    ----------
    args : argparse.Namespace
        The namespace on which the arguments will be set

    num_procs : int
        The number of processes for a local cluster

    num_threads_per_proc : int
        The number of threads for each process for a local cluster

    cluster_location : str
        The location of the cluster

    client_restart : bool
        Whether to restart the client after connection

    Returns
    -------
    None : None
        The respective options will be set on the namespace
    """
    args.num_procs = num_procs
    args.num_threads_per_proc = num_threads_per_proc
    args.cluster_location = cluster_location
    args.client_restart = client_restart

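# A sketch of the intended notebook workflow: start from an empty namespace,
# fill in the dask fields, and connect.
#
#     import argparse
#     args = argparse.Namespace()
#     add_dask_values_to_args(args, num_procs=2)
#     client, cluster = connect(args)
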
def get_dask_cmd_options(args:argparse.Namespace) -> List[str]:
    """ Extract the flags and options specified for dask from the parsed
    arguments.

    Presumably, these were added with `add_dask_options`. This function
    returns the arguments as an array. Thus, they are suitable for use
    with `subprocess.run` and similar functions.

    Parameters
    ----------
    args : argparse.Namespace
        The parsed arguments

    Returns
    -------
    dask_options : typing.List[str]
        The list of dask options and their values.
    """
    args_dict = vars(args)

    # first, pull out the text arguments
    dask_options = [
        'num_procs',
        'num_threads_per_proc',
        'cluster_location'
    ]

    # create a list of command line arguments
    ret = []

    for o in dask_options:
        arg = str(args_dict[o])

        if len(arg) > 0:
            ret.append('--{}'.format(o.replace('_', '-')))
            ret.append(arg)

    if args.client_restart:
        ret.append("--client-restart")

    ret = [shlex.quote(c) for c in ret]

    return ret

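# A sketch of forwarding the dask options to a child process; "my-program"
# is a hypothetical executable.
#
#     import subprocess
#     cmd = ["my-program"] + get_dask_cmd_options(args)
#     subprocess.run(cmd)
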
###
# Helpers to submit arbitrary jobs to a dask cluster
###
def apply_iter(
        it:Iterable,
        client:dask.distributed.Client,
        func:Callable,
        *args,
        return_futures:bool=False,
        progress_bar:bool=True,
        priority:int=0,
        **kwargs) -> List:
    """ Distribute calls to `func` on each item in `it` across `client`.

    Parameters
    ----------
    it : typing.Iterable
        The inputs for `func`

    client : dask.distributed.Client
        A dask client

    func : typing.Callable
        The function to apply to each item in `it`

    args
        Positional arguments to pass to `func`

    kwargs
        Keyword arguments to pass to `func`

    return_futures : bool
        Whether to wait for the results (`False`, the default) or return a
        list of dask futures (when `True`). If a list of futures is returned,
        the `result` method should be called on each of them at some point
        before attempting to use the results.

    progress_bar : bool
        Whether to show a progress bar when waiting for results. The
        parameter is only relevant when `return_futures` is `False`.

    priority : int
        The priority of the submitted tasks. Please see the dask
        documentation for more details:
        http://distributed.readthedocs.io/en/latest/priority.html

    Returns
    -------
    results: typing.List
        Either the result of each function call or a future which will give
        the result, depending on the value of `return_futures`
    """
    msg = "[dask_utils.apply_iter] submitting jobs to cluster"
    logger.debug(msg)

    if progress_bar:
        it = tqdm.tqdm(it)

    ret_list = [
        client.submit(func, *(i, *args), **kwargs, priority=priority)
        for i in it
    ]

    if return_futures:
        return ret_list

    msg = "[dask_utils.apply_iter] collecting results from cluster"
    logger.debug(msg)

    # add a progress bar if we asked for one
    if progress_bar:
        ret_list = tqdm.tqdm(ret_list)

    ret_list = [r.result() for r in ret_list]
    return ret_list

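# A minimal sketch of distributing a function over an iterable; `square` is
# a stand-in for any picklable callable.
#
#     def square(x):
#         return x * x
#
#     results = apply_iter(range(10), client, square)
#     # results == [0, 1, 4, ..., 81]
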
def apply_df(
        data_frame:pd.DataFrame,
        client:dask.distributed.Client,
        func:typing.Callable,
        *args,
        return_futures:bool=False,
        progress_bar:bool=True,
        priority:int=0,
        **kwargs) -> List:
    """ Distribute calls to `func` on each row in `data_frame` across
    `client`.

    Additionally, `args` and `kwargs` are passed to the function call.

    Parameters
    ----------
    data_frame: pandas.DataFrame
        A data frame

    client: dask.distributed.Client
        A dask client

    func: typing.Callable
        The function to apply to each row in `data_frame`

    args
        Positional arguments to pass to `func`

    kwargs
        Keyword arguments to pass to `func`

    return_futures: bool
        Whether to wait for the results (`False`, the default) or return a
        list of dask futures (when `True`). If a list of futures is returned,
        the `result` method should be called on each of them at some point
        before attempting to use the results.

    progress_bar: bool
        Whether to show a progress bar when waiting for results. The
        parameter is only relevant when `return_futures` is `False`.

    priority : int
        The priority of the submitted tasks. Please see the dask
        documentation for more details:
        http://distributed.readthedocs.io/en/latest/priority.html

    Returns
    -------
    results: typing.List
        Either the result of each function call or a future which will give
        the result, depending on the value of `return_futures`
    """
    if len(data_frame) == 0:
        return []

    it = data_frame.iterrows()
    if progress_bar:
        it = tqdm.tqdm(it, total=len(data_frame))

    ret_list = [
        client.submit(func, *(row[1], *args), **kwargs, priority=priority)
        for row in it
    ]

    if return_futures:
        return ret_list

    # add a progress bar if we asked for one
    if progress_bar:
        ret_list = tqdm.tqdm(ret_list, total=len(data_frame))

    ret_list = [r.result() for r in ret_list]
    return ret_list

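# A sketch of a row-wise apply. Each submitted call receives one row as a
# pandas.Series; the frame and column names are illustrative.
#
#     df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
#
#     def row_sum(row):
#         return row['a'] + row['b']
#
#     results = apply_df(df, client, row_sum)
#     # results == [4, 6]
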
def apply_groups(
        groups:pd.core.groupby.DataFrameGroupBy,
        client:dask.distributed.client.Client,
        func:typing.Callable,
        *args,
        return_futures:bool=False,
        progress_bar:bool=True,
        priority:int=0,
        **kwargs) -> typing.List:
    """ Distribute calls to `func` on each group in `groups` across `client`.

    Additionally, `args` and `kwargs` are passed to the function call.

    Parameters
    ----------
    groups: pandas.DataFrameGroupBy
        The result of a call to `groupby` on a data frame

    client: distributed.Client
        A dask client

    func: typing.Callable
        The function to apply to each group in `groups`

    args
        Positional arguments to pass to `func`

    kwargs
        Keyword arguments to pass to `func`

    return_futures: bool
        Whether to wait for the results (`False`, the default) or return a
        list of dask futures (when `True`). If a list of futures is returned,
        the `result` method should be called on each of them at some point
        before attempting to use the results.

    progress_bar: bool
        Whether to show a progress bar when waiting for results. The
        parameter is only relevant when `return_futures` is `False`.

    priority : int
        The priority of the submitted tasks. Please see the dask
        documentation for more details:
        http://distributed.readthedocs.io/en/latest/priority.html

    Returns
    -------
    results: typing.List
        Either the result of each function call or a future which will give
        the result, depending on the value of `return_futures`.
    """
    if len(groups) == 0:
        return []

    it = groups
    if progress_bar:
        it = tqdm.tqdm(it)

    ret_list = [
        client.submit(func, *(group, *args), **kwargs, priority=priority)
        for name, group in it
    ]

    if return_futures:
        return ret_list

    # add a progress bar if we asked for one
    if progress_bar:
        ret_list = tqdm.tqdm(ret_list)

    ret_list = [r.result() for r in ret_list]
    return ret_list

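# A sketch of a group-wise apply. Each submitted call receives one group's
# sub-frame; the frame and grouping key are illustrative.
#
#     df = pd.DataFrame({'key': ['x', 'x', 'y'], 'val': [1, 2, 3]})
#
#     def group_total(group):
#         return group['val'].sum()
#
#     results = apply_groups(df.groupby('key'), client, group_total)
#     # results == [3, 3]
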
def check_status(
        f_list:Iterable[dask.distributed.client.Future]) -> collections.Counter:
    """ Collect the status counts of a list of futures

    This is primarily intended to check the status of jobs submitted with
    the various `apply` functions when `return_futures` is `True`.

    Parameters
    ----------
    f_list: typing.List[dask.distributed.client.Future]
        The list of futures

    Returns
    -------
    status_counter: collections.Counter
        The number of futures with each status
    """
    counter = collections.Counter([f.status for f in f_list])
    return counter

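# A sketch of the non-blocking workflow: submit with `return_futures=True`
# and poll the status counts (`square` as sketched above; the counts shown
# are illustrative).
#
#     futures = apply_iter(range(100), client, square, return_futures=True)
#     check_status(futures)
#     # e.g., Counter({'finished': 63, 'pending': 37})
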
def collect_results(
        f_list:Iterable[dask.distributed.client.Future],
        finished_only:bool=True,
        progress_bar:bool=False) -> List:
    """ Collect the results from a list of futures

    By default, only results from finished tasks will be collected. Thus,
    the function is (more or less) non-blocking.

    Parameters
    ----------
    f_list: typing.List[dask.distributed.client.Future]
        The list of futures

    finished_only: bool
        Whether to collect only results for jobs whose status is 'finished'

    progress_bar : bool
        Whether to show a progress bar while collecting the results

    Returns
    -------
    results: typing.List
        The results for each (finished, if specified) task
    """
    if progress_bar:
        f_list = tqdm.tqdm(f_list)

    if finished_only:
        ret = [f.result() for f in f_list if f.status == 'finished']
    else:
        ret = [f.result() for f in f_list]

    return ret

def cancel_all(
        f_list:Iterable[dask.distributed.client.Future],
        pending_only:bool=True) -> None:
    """ Cancel all (pending) tasks in the list

    By default, only pending tasks are cancelled.

    Parameters
    ----------
    f_list : Iterable[dask.distributed.client.Future]
        The list of futures

    pending_only : bool
        Whether to cancel only tasks whose status is 'pending'

    Returns
    -------
    None : None
        The specified tasks are cancelled.
    """
    if pending_only:
        for f in f_list:
            if f.status == 'pending':
                f.cancel()
    else:
        for f in f_list:
            f.cancel()

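# Continuing the sketch above: harvest the finished results and cancel
# whatever is still pending.
#
#     done = collect_results(futures, finished_only=True)
#     cancel_all(futures, pending_only=True)
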
###
# A simple wrapper to submit an sklearn pipeline to a dask cluster for fitting
###
class dask_pipeline:
    """ This class is a simple wrapper to submit an sklearn pipeline to a
    dask cluster for fitting.

    Examples
    --------

    .. code-block:: python

        my_pipeline = sklearn.pipeline.Pipeline(steps)
        d_pipeline = dask_pipeline(my_pipeline, dask_client)
        d_pipeline_fit = d_pipeline.fit(X, y)
        pipeline_fit = d_pipeline_fit.collect_results()
    """
    def __init__(self, pipeline, dask_client):
        self.pipeline = pipeline
        self.dask_client = dask_client

    def fit(self, X, y):
        """ Submit the call to `fit` of the underlying pipeline to
        `dask_client` """
        self.d_fit = self.dask_client.submit(self.pipeline.fit, X, y)
        return self

    def collect_results(self):
        """ Collect the "fit" pipeline from `dask_client`. Then, clean up
        the references to the future and client. """
        self.pipeline_fit = self.d_fit.result()

        # and clean up
        del self.d_fit
        del self.dask_client

        return self.pipeline_fit