Coverage for langbrainscore/benchmarks/pereira2018.py: 19% (52 statements)
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import xarray as xr
from pathlib import Path
from langbrainscore.utils.logging import log
from langbrainscore.utils.xarray import collapse_multidim_coord
from langbrainscore.dataset import Dataset


def _pereira2018_mean_froi() -> xr.DataArray:
    """Package the Pereira 2018 first-session Lang/MD mean fROI effect sizes from
    the bundled CSV into an xarray.DataArray with dims (sampleid, neuroid, timeid)."""
    source = (
        Path(__file__).parents[2]
        / "data/Pereira_FirstSession_TrialEffectSizes_20220223.csv"
    )
    mpf = pd.read_csv(source)
    mpf = mpf.sort_values(by=["UID", "Session", "Experiment", "Stim"])
    subj_xrs = []
    neuroidx = 0
    for uid in tqdm(mpf.UID.unique()):
        mpf_subj = mpf[mpf.UID == uid]
        sess_xrs = []
        for sess in mpf_subj.Session.unique():
            mpf_sess = mpf_subj[mpf_subj.Session == sess]
            # keep only the Lang and MD network fROI columns
            roi_filt = [any(n in c for n in ["Lang", "MD"]) for c in mpf_sess.columns]
            mpf_rois = mpf_sess.iloc[:, roi_filt]
            # add a singleton `timeid` axis: (samples, rois) -> (samples, rois, 1)
            data_array = np.expand_dims(mpf_rois.values, 2)
            sess_xr = xr.DataArray(
                data_array,
                dims=("sampleid", "neuroid", "timeid"),
                coords={
                    # sessions with 384 stimuli occupy sampleids 0-383; sessions
                    # with 243 stimuli are offset to 384-626 so the two stimulus
                    # sets do not collide along `sampleid`
                    "sampleid": (
                        np.arange(0, 384)
                        if data_array.shape[0] == 384
                        else np.arange(384, 384 + 243)
                    ),
                    "neuroid": np.arange(neuroidx, neuroidx + data_array.shape[1]),
                    "timeid": np.arange(data_array.shape[2]),
                    "stimulus": ("sampleid", mpf_sess.Sentence.str.strip('"')),
                    "passage": (
                        "sampleid",
                        list(map(lambda p_s: p_s.split("_")[0], mpf_sess.Stim)),
                    ),
                    "experiment": ("sampleid", mpf_sess.Experiment),
                    "session": (
                        "neuroid",
                        np.array(
                            [mpf_sess.Session.values[0]] * data_array.shape[1],
                            dtype=object,
                        ),
                    ),
                    "subject": (
                        "neuroid",
                        [mpf_sess.UID.values[0]] * data_array.shape[1],
                    ),
                    "roi": ("neuroid", mpf_rois.columns),
                },
            )
            sess_xrs.append(sess_xr)
        # neuroid indices are shared across a subject's sessions,
        # so advance the offset once per subject
        neuroidx += data_array.shape[1]
        subj_xr = xr.concat(sess_xrs, dim="sampleid")
        subj_xrs.append(subj_xr)

    mpf_xr = xr.concat(subj_xrs, dim="neuroid")

    mpf_xr = collapse_multidim_coord(mpf_xr, "stimulus", "sampleid")
    mpf_xr = collapse_multidim_coord(mpf_xr, "passage", "sampleid")
    mpf_xr = collapse_multidim_coord(mpf_xr, "experiment", "sampleid")
    mpf_xr = collapse_multidim_coord(mpf_xr, "session", "neuroid")

    mpf_xr.attrs["source"] = str(source)
    mpf_xr.attrs["measurement"] = "fmri"
    mpf_xr.attrs["modality"] = "text"
    # mpf_xr.attrs["name"] = f"pereira2018_mean_froi"

    return mpf_xr


def pereira2018_mean_froi(network="Lang", load_cache=True) -> Dataset:
    """Return the Pereira 2018 mean fROI data as a `Dataset`, optionally restricted
    to a single network (e.g. "Lang") and loaded from cache when available."""
    dataset_name = (
        f"pereira2018_mean_froi_{network}" if network else "pereira2018_mean_froi"
    )

    def package() -> Dataset:
        mpf_xr = _pereira2018_mean_froi()
        if network:
            mpf_xr = mpf_xr.isel(neuroid=mpf_xr.roi.str.contains(network))
        mpf_dataset = Dataset(
            mpf_xr,
            dataset_name=dataset_name,
            # modality="text"
        )
        return mpf_dataset

    if load_cache:
        try:
            # try to load a previously cached copy of the packaged dataset;
            # the empty DataArray is only a placeholder (checks are skipped)
            mpf_dataset = Dataset(
                xr.DataArray(),
                dataset_name=dataset_name,
                # modality="text",
                _skip_checks=True,
            )
            mpf_dataset.load_cache()
        except FileNotFoundError:
            # no cache on disk yet; package the dataset from the source CSV
            mpf_dataset = package()
    else:
        mpf_dataset = package()
    # mpf_dataset.to_cache()

    return mpf_dataset
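

# Usage sketch (not part of the original module, included here as an illustrative
# assumption): build the packaged DataArray directly and restrict it to the Lang
# network, mirroring what `pereira2018_mean_froi(network="Lang")` does before
# wrapping the result in a Dataset. Requires the source CSV to be present at the
# path constructed in `_pereira2018_mean_froi`.
if __name__ == "__main__":
    mpf_xr = _pereira2018_mean_froi()
    lang_xr = mpf_xr.isel(neuroid=mpf_xr.roi.str.contains("Lang"))
    log(
        f"packaged {lang_xr.sizes['sampleid']} samples x "
        f"{lang_xr.sizes['neuroid']} Lang fROIs from {lang_xr.attrs['source']}"
    )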