Source code for process_nwb.wavelet_transform

import numpy as np

from process_nwb.resample import resample
from scipy.fft import fftfreq, fft, ifft

from pynwb.misc import DecompositionSeries
from hdmf.data_utils import AbstractDataChunkIterator, DataChunk
from hdmf.backends.hdf5.h5_utils import H5DataIO

from process_nwb.utils import (_npads, _smart_pad, _trim,
                               log_spaced_cfs, const_Q_sds,
                               chang_sds, dtype)


[docs]def gaussian(n_time, rate, center, sd, precision='single'): """Generates a normalized gaussian kernel. Parameters ---------- n_time : int Number of samples rate : float Sampling rate of kernel (Hz). center : float Center frequency (Hz). sd : float Bandwidth (Hz). precision : str Either `single` for float32/complex64 or `double` for float/complex. """ freq = fftfreq(n_time, 1. / rate) X_dtype = dtype(freq, precision) k = np.exp((-(np.abs(freq) - center) ** 2) / (2 * (sd ** 2))) k /= np.linalg.norm(k) return k.astype(X_dtype, copy=False)
[docs]def hamming(n_time, rate, min_freq, max_freq, precision='single'): """Generates a normalized Hamming kernel. Parameters ---------- n_time : int Number of samples rate : float Sampling rate of kernel (Hz). min_freq : float Band minimum frequency (Hz). max_freq : float Band maximum frequency (Hz). precision : str Either `single` for float32/complex64 or `double` for float/complex. """ freq = fftfreq(n_time, 1. / rate) X_dtype = dtype(freq, precision) pos_in_window = np.logical_and(freq >= min_freq, freq <= max_freq) neg_in_window = np.logical_and(freq <= -min_freq, freq >= -max_freq) k = np.zeros(len(freq)) window_size = np.count_nonzero(pos_in_window) window = np.hamming(window_size) k[pos_in_window] = window window_size = np.count_nonzero(neg_in_window) window = np.hamming(window_size) k[neg_in_window] = window k /= np.linalg.norm(k) return k.astype(X_dtype, copy=False)
[docs]def get_filterbank(filters, n_time, rate, hg_only, precision='single'): """Get the filterbank and parameters. Parameters ---------- filters : str or list Which type of filters to use. Options are 'rat': center frequencies spanning 2-1200 Hz, constant Q, 54 bands 'human': center frequencies spanning 4-200 Hz, constant Q, 40 bands 'changlab': center frequencies spanning 4-200 Hz, variable Q, 40 bands Note - calculating center frequencies above rate/2 raises a ValueError If filters is a list, it is assumed to already be correctly formatted. n_time : int Input data time dimension. rate : float Number of samples per second. hg_only : bool If True, only the amplitudes in the high gamma range [70-150 Hz] is computed. precision : str Either `single` for float32/complex64 or `double` for float/complex. Returns ------- filters : list of ndarrays List of filters to apply. cfs : ndarray Center frequencies used. sds : ndarray Bandwidths used. """ if isinstance(filters, list): return filters, None, None # Calculate center frequencies if filters in ['human', 'changlab']: cfs = log_spaced_cfs(4.0749286538265, 200, 40) elif filters == 'rat': cfs = log_spaced_cfs(2.6308, 1200., 54) else: raise NotImplementedError # Subselect high gamma bands if hg_only: idxs = np.logical_and(cfs >= 70., cfs <= 150.) cfs = cfs[idxs] # Raise exception if sample rate too small if cfs.max() * 2. > np.nextafter(rate, np.inf): # Allow floating point tolerance string = ('Unable to compute wavelet transform above Nyquist rate ({} Hz).' + ' Increase your rate ({} Hz) to at least twice your desired maximum' + 'frequency of interest.') raise ValueError(string.format(cfs.max() * 2., np.nextafter(rate, np.inf))) # Calculate bandwidths if filters in ['rat', 'human']: sds = const_Q_sds(cfs) elif filters == 'changlab': sds = chang_sds(cfs) else: raise NotImplementedError filters = [] for cf, sd in zip(cfs, sds): filters.append(gaussian(n_time, rate, cf, sd, precision=precision)) return filters, cfs, sds
[docs]class ChannelBandIterator(AbstractDataChunkIterator): """Class for iterative write over channels and bands. Parameters ---------- X : ndarray (n_time, n_channels) Data array. filters : str (optional) Which type of filters to use. Options are 'rat': center frequencies spanning 2-1200 Hz, constant Q, 54 bands 'human': center frequencies spanning 4-200 Hz, constant Q, 40 bands 'changlab': center frequencies spanning 4-200 Hz, variable Q, 40 bands npad : int Padding to add to beginning and end of timeseries. Default 0. hg_only : bool If True, only the amplitudes in the high gamma range [70-150 Hz] is computed. post_resample_rate : float If not `None`, resample the computed wavelet amplitudes to this rate. precision : str Either `single` for float32/complex64 or `double` for float/complex. """ def __init__(self, X, rate, filters='rat', npad=None, hg_only=True, post_resample_rate=None, precision='single'): self.X_dtype = dtype(X, precision) X = X.astype(self.X_dtype, copy=False) self.X = X self.rate = rate self.npad = npad self.post_resample_rate = post_resample_rate self.precision = precision # Need to pad X before predicting chunk and filter shape: self.npads, self.to_removes, _ = _npads(X, npad) self.wavelet_time = X.shape[0] + self.npads.sum() self.filterbank, self.cfs, self.sds = get_filterbank(filters, self.wavelet_time, self.rate, hg_only, precision=self.precision) self.resample_time = self.X.shape[0] if post_resample_rate is not None: self.resample_time = int(np.ceil(self.X.shape[0] * post_resample_rate / rate)) self.nch = self.X.shape[1] self.nbands = len(self.filterbank) self._i = 0 def __iter__(self): return self def __next__(self): ch = self._i // self.nbands band = self._i % self.nbands self._i += 1 if ch >= self.X.shape[1]: raise StopIteration if band == 0: self.X_fft_h = None X_ch = self.X[:, [ch]] X_ch = _smart_pad(X_ch, self.npads) data, self.X_fft_h, cfs, sds = wavelet_transform( X_ch, self.rate, filters=[self.filterbank[band]], X_fft_h=self.X_fft_h, npad=0, # padding happens outside to_removes=self.to_removes, precision=self.precision ) if band == 0: data = _trim(data, self.to_removes) data = np.abs(data) if self.post_resample_rate is not None: data = resample(data, self.post_resample_rate, self.rate, precision=self.precision) data = np.squeeze(data) return DataChunk(data=data, selection=np.s_[:data.shape[0], ch, band]) next = __next__ @property def dtype(self): return self.X.dtype @property def maxshape(self): return (None, self.nch, self.nbands)
[docs] def recommended_chunk_shape(self): return (self.resample_time, 1, 1)
[docs] def recommended_data_shape(self): return (self.resample_time, self.nch, self.nbands)
[docs]def wavelet_transform(X, rate, filters='rat', hg_only=True, X_fft_h=None, npad='fast', to_removes=None, precision='single'): """Apply a wavelet transform using a prespecified set of filters. Calculates the center frequencies and bandwidths for the wavelets and applies them along with a heavyside function to the fft of the signal before performing an inverse fft. Here are additional details: **1.** Computes the FFT of the signal and applies 2u(f), where u(f) = heaviside function necessary for calculating analytic signal **2.** Filters in the frequency domain by multiplying with a gaussian kernel (or equivalently a complex morlet wavelet in the time domain). The gaussian kernel location = the center frequency and the standard deviation = the center frequency / Q. For the 'rat' and 'human' filters, Q is constant with a default value of 8 **3.** Computes the IFFT returning the analytic bandpassed (via a complex morlet wavelet) signal for each filter Parameters ---------- X : ndarray (n_time, n_channels) Input data, dimensions rate : float Number of samples per second. filters : str (optional) Which type of filters to use. Options are 'rat': center frequencies spanning 2-1200 Hz, constant Q, 54 bands 'human': center frequencies spanning 4-200 Hz, constant Q, 40 bands 'changlab': center frequencies spanning 4-200 Hz, variable Q, 40 bands Note - calculating center frequencies above rate/2 raises a ValueError hg_only : bool If True, only the amplitudes in the high gamma range [70-150 Hz] is computed. X_fft_h : ndarray (n_time, n_channels) Precomputed product of X_fft and heavyside. Useful for when bands are computed independently. npad : int Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next fastest length. to_removes : int Number of samples to remove at the beginning and end of the timeseries. Default None. Only used if X_fft_h is not None. precision : str Either `single` for float32/complex64 or `double` for float/complex. Returns ------- Xh : ndarray, complex Bandpassed analytic signal X_fft_h : ndarray, complex Product of X_fft and heavyside. cfs : ndarray Center frequencies used. sds : ndarray Bandwidths used. """ if X_fft_h is None: X_dtype = dtype(X, precision) X = X.astype(X_dtype, copy=False) npads, to_removes, _ = _npads(X, npad) X = _smart_pad(X, npads) n_time = X.shape[0] else: n_time = X_fft_h.shape[0] X_fft_h = X_fft_h.astype(dtype(X_fft_h, precision), copy=False) freq = fftfreq(n_time, 1. / rate) filters, cfs, sds = get_filterbank(filters, n_time, rate, hg_only, precision=precision) Xh = np.zeros(X.shape + (len(filters),), dtype=dtype(complex(1.), precision=precision)) if X_fft_h is None: # Heavyside filter with 0 DC h = np.zeros(len(freq)) h[freq > 0] = 2. h = h[:, np.newaxis] X_fft_h = fft(X, axis=0, workers=-1) * h for ii, f in enumerate(filters): if f is None: Xh[..., ii] = ifft(X_fft_h, axis=0, workers=-1) else: f = f / np.linalg.norm(f) Xh[..., ii] = ifft(X_fft_h * f[:, np.newaxis], axis=0, workers=-1) Xh = _trim(Xh, to_removes) return Xh, X_fft_h, cfs, sds
[docs]def store_wavelet_transform(elec_series, processing, filters='rat', hg_only=True, abs_only=True, npad='fast', post_resample_rate=None, chunked=True, precision='single', source_series=None): """Apply a wavelet transform using a prespecified set of filters. Results are stored in the NWB file as a `DecompositionSeries`. Calculates the center frequencies and bandwidths for the wavelets and applies them along with a heavyside function to the fft of the signal before performing an inverse fft. The center frequencies and bandwidths are also stored in the NWB file. Parameters ---------- elec_series : ElectricalSeries ElectricalSeries to process. processing : Processing module NWB Processing module to save processed data. filters : str (optional) Which type of filters to use. Options are 'rat': center frequencies spanning 2-1200 Hz, constant Q, 54 bands 'human': center frequencies spanning 4-200 Hz, constant Q, 40 bands 'changlab': center frequencies spanning 4-200 Hz, variable Q, 40 bands hg_only : bool If True, only the amplitudes in the high gamma range [70-150 Hz] is computed. abs_only : bool If True, only the amplitude is stored. npad : int Padding to add to beginning and end of timeseries. Default 'fast', which pads to the next fastest length. post_resample_rate : float If not `None`, resample the computed wavelet amplitudes to this rate. chunked : bool If True, calculate wavelet transform one channel and band at a time and store iteratively into nwb. Default True precision : str Either `single` for float32/complex64 or `double` for float/complex. Default single. source_series : ElectricalSeries If not None, this series gets used as the source rather than `elec_series`. Can be used if not all intermediate series are being stored in the NWB during preprocessing. Returns ------- X_wvlt : ndarray, complex Complex wavelet coefficients. series : list of DecompositionSeries List of NWB objects. """ X = elec_series.data[:] X_dtype = dtype(X, precision) X = X.astype(X_dtype, copy=False) rate = elec_series.rate if source_series is None: source_series = elec_series final_rate = rate if post_resample_rate is not None: final_rate = post_resample_rate if chunked: if not abs_only: raise NotImplementedError("Phase is not implemented for chunked wavelet transform.") X_wvlt_abs = ChannelBandIterator(X, rate, filters=filters, npad=npad, hg_only=hg_only, post_resample_rate=post_resample_rate, precision=precision) cfs = X_wvlt_abs.cfs sds = X_wvlt_abs.sds elec_series_wvlt_amp = DecompositionSeries('wvlt_amp_' + elec_series.name, H5DataIO(X_wvlt_abs, compression=True, shuffle=True, fletcher32=True), metric='amplitude', source_timeseries=source_series, starting_time=elec_series.starting_time, rate=final_rate, description=('Wavlet: ' + elec_series.description)) output_series = [elec_series_wvlt_amp] X_wvlt = None else: X_wvlt, _, cfs, sds = wavelet_transform(X, rate, filters=filters, hg_only=hg_only, npad=npad, precision=precision) amplitude = abs(X_wvlt) if post_resample_rate is not None: amplitude = resample(amplitude, post_resample_rate, rate, precision=precision) X_wvlt = amplitude rate = post_resample_rate elec_series_wvlt_amp = DecompositionSeries('wvlt_amp_' + elec_series.name, H5DataIO(amplitude, compression=True, shuffle=True, fletcher32=True), metric='amplitude', source_timeseries=source_series, starting_time=elec_series.starting_time, rate=final_rate, description=('Wavlet: ' + elec_series.description)) output_series = [elec_series_wvlt_amp] if not abs_only: if post_resample_rate is not None: raise ValueError('Wavelet phase should not be resampled.') elec_series_wvlt_phase = DecompositionSeries('wvlt_phase_' + elec_series.name, H5DataIO(np.angle(X_wvlt), compression=True, shuffle=True, fletcher32=True), metric='phase', source_timeseries=source_series, starting_time=elec_series.starting_time, rate=final_rate, description=('Wavlet: ' + elec_series.description)) output_series.append(elec_series_wvlt_phase) for es in output_series: for ii, (cf, sd) in enumerate(zip(cfs, sds)): es.add_band(band_name=str(ii), band_mean=cf, band_stdev=sd, band_limits=(-1, -1)) processing.add(es) return X_wvlt, output_series