Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge IFR #265

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ notifications:
env:
- TRAVIS_CI=1

# set build size to avoid 137 out of memory error
vm:
size: large

# conda setup copied from the conda docs
install:
# We do this conditionally because it saves us some downloading if the
Expand Down
9 changes: 9 additions & 0 deletions ramutils/cli/expconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,17 @@ def validate_stim_settings(args):
if args.experiment != "AmplitudeDetermination" and not args.experiment.startswith('PS'):
if args.target_amplitudes is None:
raise RuntimeError("--target-amplitudes is required")
if not (args.min_amplitudes is None and args.max_amplitudes is None):
raise RuntimeError('--min-amplitudes and --max-amplitudes are only used with '
'AmplitudeDetermination and "PS" experiments. To specify '
'a config for ' + args.experiment + ', only specify '
'--target-amplitudes')
valid = len(args.anodes) == len(args.target_amplitudes)
else:
if not args.target_amplitudes is None:
raise RuntimeError('Cannot specify --target-amplitudes (which is only used in '
'"PS" and AtmplitudeDetermination experiments) with '
'--min-amplitudes or --max-amplitudes!')
valid = len(args.anodes) == len(
args.min_amplitudes) == len(args.max_amplitudes)

Expand Down
12 changes: 11 additions & 1 deletion ramutils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
'CatFR1',
'PAL1',
'DBOY1',
'RepFR1'
'RepFR1',
'IFR1',
'ICatFR1'
],
'ps': [
'PS4_FR5',
Expand All @@ -38,6 +40,12 @@
'CatFR6',
'TICL_FR',
'TICL_CatFR',
'IFR3',
'ICatFR3',
'IFR5',
'ICatFR5',
'IFR6',
'ICatFR6',
],

# Experiments that allow multiple stim locations
Expand All @@ -48,6 +56,8 @@
'PS4_CatFR5',
'FR6',
'CatFR6',
'IFR6',
'ICatFR6',
]
}

Expand Down
62 changes: 59 additions & 3 deletions ramutils/tasks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from ramutils.tasks import task
from ramutils.utils import extract_experiment_series

from ..utils import insert_column
import numpy as np

__all__ = [
'get_word_event_mask',
'subset_events',
Expand Down Expand Up @@ -67,8 +70,36 @@ def build_training_data(subject, experiment, paths, sessions=None, **kwargs):
pre=kwargs['pre_event_buf'],
post=kwargs['post_event_buf'])

ifr_events = load_events(subject, 'IFR1', sessions=sessions,
rootdir=paths.root)
cleaned_ifr_events = clean_events(ifr_events,
start_time=kwargs['baseline_removal_start_time'],
end_time=kwargs['retrieval_time'],
duration=kwargs['empty_epoch_duration'],
pre=kwargs['pre_event_buf'],
post=kwargs['post_event_buf'])
if 'iscorrect' not in cleaned_ifr_events.dtype.names:
iscorrect_index = fr_events.dtype.names.index('iscorrect')
cleaned_ifr_events = insert_column(cleaned_ifr_events, 'iscorrect',
np.full(cleaned_ifr_events.shape, -999), int, iscorrect_index)

icatfr_events = load_events(subject, 'ICatFR1',
sessions=sessions,
rootdir=paths.root)
cleaned_icatfr_events = clean_events(icatfr_events,
start_time=kwargs['baseline_removal_start_time'],
end_time=kwargs['retrieval_time'],
duration=kwargs['empty_epoch_duration'],
pre=kwargs['pre_event_buf'],
post=kwargs['post_event_buf'])

if 'iscorrect' not in cleaned_icatfr_events.dtype.names:
iscorrect_index = catfr_events.dtype.names.index('iscorrect')
cleaned_icatfr_events = insert_column(cleaned_icatfr_events, 'iscorrect',
np.full(cleaned_icatfr_events.shape, -999), int, iscorrect_index)

free_recall_events = concatenate_events_across_experiments(
[cleaned_fr_events, cleaned_catfr_events], cat=True)
[cleaned_fr_events, cleaned_catfr_events, cleaned_ifr_events, cleaned_icatfr_events], cat=True)

elif "FR" in experiment and not kwargs['combine_events']:
free_recall_events = load_events(subject, experiment, sessions=sessions,
Expand Down Expand Up @@ -140,13 +171,38 @@ def build_test_data(subject, experiment, paths, joint_report, sessions=None,
pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf'],
return_stim_events=True)

# Immediate Free Recall variants
ifr_events = load_events(subject, 'IFR' + series_num,
sessions=sessions,
rootdir=paths.root)
cleaned_ifr_events, ifr_stim_params = clean_events(
ifr_events, start_time=kwargs['baseline_removal_start_time'],
end_time=kwargs['retrieval_time'],
duration=kwargs['empty_epoch_duration'],
pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf'],
return_stim_events=True)

icatfr_events = load_events(subject, 'ICatFR' + series_num,
sessions=sessions,
rootdir=paths.root)
cleaned_icatfr_events, icatfr_stim_params = clean_events(
icatfr_events, start_time=kwargs['baseline_removal_start_time'],
end_time=kwargs['retrieval_time'],
duration=kwargs['empty_epoch_duration'],
pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf'],
return_stim_events=True)

all_events = concatenate_events_across_experiments([fr_events,
catfr_events])
catfr_events,
ifr_events,
icatfr_events])
task_events = concatenate_events_across_experiments(
[cleaned_fr_events, cleaned_catfr_events], cat=True)

stim_params = concatenate_events_across_experiments([fr_stim_params,
catfr_stim_params],
catfr_stim_params,
ifr_stim_params,
icatfr_stim_params],
stim=True)

elif not joint_report and 'FR' in experiment:
Expand Down
2 changes: 1 addition & 1 deletion ramutils/tasks/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def summarize_stim_sessions(all_events, task_events, stim_params, pairs_data,
# TODO: Add some sort of data quality check here potentially. Do the
# observed stim items match what we expect from classifier output?

if experiment in ['FR3', 'FR5', 'catFR3', 'catFR5', 'FR6', 'catFR6']:
if experiment in ['FR3', 'FR5', 'catFR3', 'catFR5', 'FR6', 'catFR6', 'ICatFR5', 'FR5', 'ICatFR6', 'FR6']:
stim_events = dataframe_to_recarray(stim_df, expected_dtypes)
stim_session_summary = FRStimSessionSummary()
stim_session_summary.populate(
Expand Down
29 changes: 29 additions & 0 deletions ramutils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,32 @@ def encode_file(fd):
"""
fd.seek(0)
return base64.b64encode(fd.read()).decode()


def insert_column(recarr, column_name, data, dtype, position):
"""
Insert a column into a recarray at a specific position.

Parameters:
- recarr: the original recarray
- column_name: the name of the new column to be added
- data: the data for the new column
- dtype: the data type for the new column
- position: the position to insert the new column at
Returns:
- new_arr: a new recarray with the inserted column
"""

if position > len(recarr.dtype.names):
raise ValueError("Position is out of range.")

before = [(name, recarr.dtype.fields[name][0]) for name in recarr.dtype.names[:position]]
after = [(name, recarr.dtype.fields[name][0]) for name in recarr.dtype.names[position:]]
new_dtype = np.dtype(before + [(column_name, dtype)] + after)

new_arr = np.rec.array(np.zeros(recarr.shape, dtype=new_dtype))

for name in recarr.dtype.names:
new_arr[name] = recarr[name]
new_arr[column_name] = data
return new_arr