Skip to content

Instantly share code, notes, and snippets.

@CBroz1
Created August 5, 2024 17:22
Show Gist options
  • Save CBroz1/ee34516885ac8d3ea6cc00043479030c to your computer and use it in GitHub Desktop.
Save CBroz1/ee34516885ac8d3ea6cc00043479030c to your computer and use it in GitHub Desktop.
"""V1"""
from functools import reduce
from typing import List, Union
import numpy as np
import spikeinterface as si
from spikeinterface.core.job_tools import ChunkRecordingExecutor, ensure_n_jobs
from spyglass.common.common_interval import (
_union_concat,
interval_from_inds,
interval_list_complement,
)
from spyglass.spikesorting.utils import (
_check_artifact_thresholds,
_compute_artifact_chunk,
_init_artifact_worker,
)
from spyglass.utils import logger
def _get_artifact_times(
    recording: si.BaseRecording,
    sort_interval_valid_times: List[List],
    zscore_thresh: Union[float, None] = None,
    amplitude_thresh_uV: Union[float, None] = None,
    proportion_above_thresh: float = 1.0,
    removal_window_ms: float = 1.0,
    verbose: bool = False,
    **job_kwargs,
):
    """Detect artifact times and the valid times that remain without them.

    Above-threshold frames are found chunk-by-chunk with
    ``ChunkRecordingExecutor``. Each run of consecutive artifact frames
    becomes an interval. The interval is padded by half of
    ``removal_window_ms`` on each side and snapped to recording timestamps.
    Overlapping padded intervals are merged, then subtracted from
    ``sort_interval_valid_times``.

    Parameters
    ----------
    recording : si.BaseRecording
        Recording to scan. Its ``get_times()`` supplies the timestamps used
        to convert frame indices to times. These are assumed to be sorted,
        as ``np.searchsorted`` requires.
    sort_interval_valid_times : List[List]
        ``[start, end]`` intervals from which the artifact intervals are
        removed. Assumed to share the time base of
        ``recording.get_times()``.
    zscore_thresh : float or None, optional
        Z-score threshold. Detection is skipped entirely when both this and
        ``amplitude_thresh_uV`` are None.
    amplitude_thresh_uV : float or None, optional
        Amplitude threshold in microvolts.
    proportion_above_thresh : float, optional
        Validated by ``_check_artifact_thresholds`` and forwarded to the
        artifact workers. Presumably the fraction of channels that must
        exceed threshold for a frame to count (TODO confirm). Default 1.0.
    removal_window_ms : float, optional
        Total width, in milliseconds, removed around each artifact. Half is
        added before its first frame and half after its last. Default 1.0.
    verbose : bool, optional
        Forwarded to ``ChunkRecordingExecutor``.
    **job_kwargs
        Forwarded to ``ChunkRecordingExecutor``. ``n_jobs`` (default 1)
        also selects how the recording is handed to the workers.

    Returns
    -------
    artifact_removed_valid_times : np.ndarray
        Valid-time intervals with artifacts removed. If detection is
        skipped or finds nothing, this is ``[[first_ts, last_ts]]`` of the
        recording. It is empty if artifacts cover all of
        ``sort_interval_valid_times``.
    artifact_times : np.ndarray
        Padded artifact intervals, merged where they overlap. It is an
        empty array if detection is skipped or finds nothing.
    """
    valid_timestamps = recording.get_times()

    # NOTE(review): both "nothing removed" paths below return the full
    # recording span rather than sort_interval_valid_times. Confirm callers
    # expect this.

    # If both thresholds are None, skip artifact detection.
    if amplitude_thresh_uV is None and zscore_thresh is None:
        logger.info(
            "Amplitude and zscore thresholds are both None, "
            "skipping artifact detection"
        )
        # 2D, matching the "no artifacts detected" return below.
        return (
            np.asarray([[valid_timestamps[0], valid_timestamps[-1]]]),
            np.asarray([]),
        )

    # Validate / normalize the threshold parameters.
    (
        amplitude_thresh_uV,
        zscore_thresh,
        proportion_above_thresh,
    ) = _check_artifact_thresholds(
        amplitude_thresh=amplitude_thresh_uV,
        zscore_thresh=zscore_thresh,
        proportion_above_thresh=proportion_above_thresh,
    )

    # Detect above-threshold frames chunk by chunk, possibly in parallel.
    n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1))
    logger.info(f"Using {n_jobs} jobs...")
    # With more than one job the workers get recording.to_dict() instead of
    # the object itself (presumably so it can be sent to worker processes).
    worker_recording = recording if n_jobs == 1 else recording.to_dict()
    executor = ChunkRecordingExecutor(
        recording=recording,
        func=_compute_artifact_chunk,
        init_func=_init_artifact_worker,
        init_args=(
            worker_recording,
            zscore_thresh,
            amplitude_thresh_uV,
            proportion_above_thresh,
        ),
        verbose=verbose,
        handle_returns=True,
        job_name="detect_artifact_frames",
        **job_kwargs,
    )
    artifact_frames = np.concatenate(executor.run())

    if len(artifact_frames) == 0:
        logger.warning("No artifacts detected.")
        return (
            np.asarray([[valid_timestamps[0], valid_timestamps[-1]]]),
            np.asarray([]),
        )

    # Convert runs of consecutive artifact frames to [first, last] frame
    # intervals.
    artifact_intervals = np.asarray(interval_from_inds(artifact_frames))

    # removal_window_ms is the total width removed. Pad each side by half
    # of it, converted to seconds, and snap the padded bounds to
    # recording timestamps.
    half_removal_window_s = removal_window_ms / 2 / 1000
    start_inds = np.searchsorted(
        valid_timestamps,
        valid_timestamps[artifact_intervals[:, 0]] - half_removal_window_s,
    )
    stop_inds = np.searchsorted(
        valid_timestamps,
        valid_timestamps[artifact_intervals[:, 1]] + half_removal_window_s,
    )
    # searchsorted returns len(valid_timestamps) when the padded end lies
    # past the last sample. Clamp so an artifact near the end of the
    # recording does not raise IndexError.
    stop_inds = np.minimum(stop_inds, len(valid_timestamps) - 1)
    artifact_intervals_s = np.zeros(
        (len(artifact_intervals), 2), dtype=np.float64
    )
    artifact_intervals_s[:, 0] = valid_timestamps[start_inds]
    artifact_intervals_s[:, 1] = valid_timestamps[stop_inds]

    # Merge overlapping padded intervals so they are disjoint.
    if len(artifact_intervals_s) > 1:
        artifact_intervals_s = reduce(_union_concat, artifact_intervals_s)

    # Keep the parts of the sort interval that are not artifacts.
    artifact_removed_valid_times = np.asarray(
        interval_list_complement(
            sort_interval_valid_times, artifact_intervals_s, min_length=1
        )
    )
    if len(artifact_removed_valid_times) == 0:
        logger.warning("No valid times remain after artifact removal.")
    elif len(artifact_removed_valid_times) > 1:
        # reduce() raises TypeError on an empty sequence and unwraps a
        # single interval to 1D, so only merge when there are two or more.
        artifact_removed_valid_times = reduce(
            _union_concat, artifact_removed_valid_times
        )

    return artifact_removed_valid_times, artifact_intervals_s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment