Source code for process_nwb.pipelines

import numpy as np
import warnings

from hdmf.backends.hdf5.h5_utils import H5DataIO
from pynwb import NWBHDF5IO
from pynwb.ecephys import ElectricalSeries

from process_nwb.common_referencing import subtract_CAR, CAR
from process_nwb.linenoise_notch import apply_linenoise_notch
from process_nwb.resample import resample, store_resample, _scaling as scaling
from process_nwb.utils import dtype
from process_nwb.wavelet_transform import store_wavelet_transform


def preprocess_block(nwb_path,
                     acq_name='ECoG',
                     initial_resample_rate=3200.,
                     final_resample_rate=400.,
                     filters='rat',
                     hg_only=True,
                     all_steps=False,
                     logger=None):
    """Run the default preprocessing pipeline on an NWB file, in place.

    The pipeline performs the following steps:

    1. Resample the acquisition to `initial_resample_rate`.
    2. Remove 60 Hz line noise and subtract the common average reference
       (CAR).
    3. Compute and store a wavelet decomposition.
    4. Optionally resample the wavelet amplitudes to `final_resample_rate`.

    Parameters
    ----------
    nwb_path : str or pathlike
        Path to the .nwb file. The file is opened in append mode and is
        modified as a result of this function.
    acq_name : str
        Name of the acquisition, either 'ECoG' or 'Poly'.
    initial_resample_rate : float
        Frequency (in Hz) to resample to before performing the wavelet
        transform.
    final_resample_rate : float
        Frequency (in Hz) to resample to after calculating the wavelet
        amplitudes.
    filters : str
        Type of filter bank to use for wavelets. Choose from
        ['rat', 'human', 'changlab'].
    hg_only : bool
        Whether to store high gamma bands only. If False, use all filters.
    all_steps : bool
        Whether to store intermediate data between preprocessing steps.
        If False, the resampled and re-referenced signal is only built in
        memory and is not added to the processing module by this function.
    logger : logger
        Optional logger passed from upstream. Nothing is logged when None.

    Returns
    -------
    Returns nothing, but changes the NWB file at nwb_path.
    A ProcessingModule with name 'preprocessing' will be added to the NWB.
    """

    def _log(message):
        # Logging is optional: stay silent when no logger was supplied.
        if logger is not None:
            logger.info(message)

    with NWBHDF5IO(nwb_path, 'a') as io:
        _log('==================================')
        _log(f'Running preprocessing for {nwb_path}.')

        nwbfile = io.read()
        try:
            raw_series = nwbfile.acquisition[acq_name]
        except KeyError:
            # in case NWB file is in a legacy format
            raw_series = nwbfile.acquisition['Raw'][acq_name]

        nwbfile.create_processing_module(name='preprocessing',
                                         description='Preprocessing.')
        module = nwbfile.processing['preprocessing']

        if all_steps:
            # Each store_* helper returns (array, series); only the series is
            # kept so the array is not held by this function.
            _log('Resampling...')
            downsampled_series = store_resample(raw_series,
                                                module,
                                                initial_resample_rate)[1]

            _log('Filtering and re-referencing...')
            electrical_series_CAR = store_linenoise_notch_CAR(downsampled_series,
                                                              module)[1]
            source = electrical_series_CAR
        else:
            original_rate = raw_series.rate

            _log('Resampling...')
            signal = resample(raw_series.data[:] * scaling,
                              initial_resample_rate,
                              original_rate)

            _log('Filtering and re-referencing...')
            signal = apply_linenoise_notch(signal, initial_resample_rate)
            signal = subtract_CAR(signal)

            # In-memory series only: it is handed to the wavelet step but is
            # never added to the processing module here.
            electrical_series_CAR = ElectricalSeries(
                'CAR_ln_downsampled_' + raw_series.name,
                signal,
                raw_series.electrodes,
                starting_time=raw_series.starting_time,
                rate=initial_resample_rate)
            source = raw_series

        _log('Running wavelet transform...')
        # The return value is not used by this function.
        store_wavelet_transform(electrical_series_CAR,
                                module,
                                filters=filters,
                                hg_only=hg_only,
                                post_resample_rate=final_resample_rate,
                                source_series=source)

        io.write(nwbfile)
        _log(f'Preprocessing added to {nwb_path}.')


def store_linenoise_notch_CAR(elec_series, processing, mean_frac=.95,
                              round_func=np.ceil, precision='single'):
    """Notch-filter line noise, remove the common average reference, and
    store both the cleaned signal and the CAR.

    A notch filter at 60 Hz and its harmonics is applied, the common average
    reference (CAR) is calculated and subtracted, and two new
    ElectricalSeries ('CAR_ln_<name>' and 'CAR_of_<name>') are added to
    `processing`.

    Parameters
    ----------
    elec_series : ElectricalSeries
        ElectricalSeries to process.
    processing : Processing module
        NWB Processing module to save processed data.
    mean_frac : float
        Fraction of the data to be taken in the mean. 0. < mean_frac <= 1.
    round_func : callable
        Function for rounding the fraction of channels.
    precision : str
        Either `single` for float32/complex64 or `double` for float/complex.

    Returns
    -------
    X_CAR_ln : ndarray, (n_time, n_channels)
        Data with line noise and CAR removed.
    elec_series_CAR_ln : ElectricalSeries
        ElectricalSeries that holds X_CAR_ln.
    """
    sampling_rate = elec_series.rate

    # Load the full dataset and cast it to the dtype selected by `precision`;
    # copy=False avoids a copy when the dtype already matches.
    raw = elec_series.data[:]
    target_dtype = dtype(raw, precision)
    raw = raw.astype(target_dtype, copy=False)

    notched = apply_linenoise_notch(raw, sampling_rate, precision=precision)
    common_avg = CAR(notched,
                     mean_frac=mean_frac,
                     round_func=round_func,
                     precision=precision)
    X_CAR_ln = notched - common_avg

    # Both stored datasets share the same HDF5 dataset options.
    h5_options = dict(compression=True, shuffle=True, fletcher32=True)

    elec_series_CAR_ln = ElectricalSeries(
        'CAR_ln_' + elec_series.name,
        H5DataIO(X_CAR_ln, **h5_options),
        elec_series.electrodes,
        starting_time=elec_series.starting_time,
        rate=sampling_rate,
        description=('CAR_lned: ' + elec_series.description))

    # UserWarnings raised while building the CAR series are silenced.
    # NOTE(review): presumably the CAR's shape does not match the electrode
    # table, which triggers the warning - confirm.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        CAR_series = ElectricalSeries(
            'CAR_of_' + elec_series.name,
            H5DataIO(common_avg, **h5_options),
            elec_series.electrodes,
            starting_time=elec_series.starting_time,
            rate=sampling_rate,
            description=('CAR: ' + elec_series.description))

    processing.add(elec_series_CAR_ln)
    processing.add(CAR_series)
    return X_CAR_ln, elec_series_CAR_ln