Notebook source code:
notebooks/12_real_world_applications__emg_sign_classification_in_spd_manifold.ipynb
Hand gesture classification with EMG data using Riemannian metrics
Lead author: Marius Guerard.
In this notebook, we use EMG time series collected by 8 electrodes placed on the skin of the arm. We are going to show how to:
Process this kind of signal into covariance matrices that can be manipulated with geomstats tools.
Apply ML algorithms to this data to classify the four hand gestures present in the data (Rock, Paper, Scissors, Ok).
Compare the different methods (using Riemannian metrics, projecting onto the tangent space, using the Euclidean metric) with each other.
Context
The data were acquired with somOS-interface, an sEMG armband that lets you interact via Bluetooth with an Android smartphone (you can contact Marius Guerard (marius.guerard@gmail.com) or Renaud Renault (renaud.armand.renault@gmail.com) for more information on how to build this kind of armband yourself).
One application is to record static signs that are linked to different actions (moving a cursor and clicking, sign recognition for command-based personal assistants, …). In these experiments, we want to evaluate the difference in performance (measured as the accuracy of sign recognition) between three real-life situations in which we change the conditions of training (when the user records signs, or “calibrates” the device) and testing (when the app guesses which sign the user is making):
What is the accuracy when doing sign recognition right after training?
What is the accuracy when calibrating, removing and replacing the armband at the same position, and then testing?
What is the accuracy when calibrating, removing the armband, and giving it to someone else who tests it without calibration?
To simulate these situations, we record data from two different users (rr and mg) in two different sessions (s1 and s2). Each user puts the armband on before every session and removes it after every session.
Quick description of the data:
Each row corresponds to one acquisition; there is an acquisition every ~4 ms for each of the 8 electrodes, which corresponds to a 250 Hz acquisition rate.
The time column is in ms.
The columns c0 to c7 correspond to the electrical value recorded at each of the 8 electrodes (arbitrary unit).
The label column corresponds to the sign being recorded by the user at this time point (‘rest’, ‘rock’, ‘paper’, ‘scissors’, or ‘ok’). ‘rest’ corresponds to a rested arm.
The exp column identifies the user (rr or mg) and the session (s1 or s2).
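The ~4 ms spacing and the 250 Hz rate are related by rate = 1000 / period, with the period in ms. A minimal sketch, using the five time stamps of session mg_s1 shown in the data preview, estimates the period as the median gap between consecutive rows (the median is an illustrative choice that tolerates the occasional 5 ms gap):

```python
import numpy as np

# Time stamps (in ms) of the first five rows of session mg_s1, from data.head().
time_ms = np.array([23, 28, 32, 36, 40])

# Sampling period: median gap between consecutive acquisitions (robust to jitter).
period_ms = float(np.median(np.diff(time_ms)))
# Sampling rate in Hz: number of acquisitions per second.
rate_hz = 1000.0 / period_ms

print(period_ms, rate_hz)  # 4.0 250.0
```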
Note: another interesting use case, not explored in this notebook, would be to test the accuracy when calibrating, removing the armband, and giving it to someone else who calibrates it on their own arm before testing it. The idea is that transfer learning might help get better results (or faster calibration) than calibrating on a single user.
Setup
Before starting this tutorial, we set the working directory to the root of the geomstats repository. The root is located with git rev-parse --show-toplevel, so the notebook has to be run from inside a clone of the geomstats repository; otherwise, replace this path with the path to your geomstats repository.
In [1]:
import os
import subprocess
import matplotlib
matplotlib.interactive(True)
import matplotlib.pyplot as plt
geomstats_gitroot_path = subprocess.check_output(
    ["git", "rev-parse", "--show-toplevel"], universal_newlines=True
)
os.chdir(geomstats_gitroot_path[:-1])
print("Working directory: ", os.getcwd())
import geomstats.backend as gs
gs.random.seed(2021)
Working directory: /home/marius/proj/geomstats
INFO: Using numpy backend
Parameters
In [2]:
N_ELECTRODES = 8
N_SIGNS = 4
The Data
In [3]:
import geomstats.datasets.utils as data_utils
data = data_utils.load_emg()
In [4]:
data.head()
Out [4]:
| | time | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | label | exp |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 23 | 127 | 123 | 128 | 134 | 125 | 128 | 130 | 124 | rest | mg_s1 |
| 1 | 28 | 126 | 130 | 128 | 119 | 129 | 128 | 126 | 133 | rest | mg_s1 |
| 2 | 32 | 129 | 130 | 127 | 125 | 129 | 129 | 127 | 130 | rest | mg_s1 |
| 3 | 36 | 127 | 128 | 126 | 123 | 128 | 127 | 125 | 131 | rest | mg_s1 |
| 4 | 40 | 127 | 128 | 129 | 124 | 127 | 129 | 127 | 128 | rest | mg_s1 |
In [5]:
fig, ax = plt.subplots(N_SIGNS, figsize=(20, 20))
label_list = ["rock", "scissors", "paper", "ok"]
for i, label_i in enumerate(label_list):
    # First 100 acquisitions recorded for this sign.
    sign_df = data[data.label == label_i].iloc[:100]
    for electrode in range(N_ELECTRODES):
        # Column 0 is time, so electrode k is in column 1 + k.
        ax[i].plot(sign_df.iloc[:, 1 + electrode])
    ax[i].title.set_text(label_i)

We remove the sign ‘rest’ for the remainder of the analysis.
In [6]:
data = data[data.label != "rest"]
Preprocessing into covariance matrices
In [7]:
import numpy as np
import pandas as pd
### Parameters.
N_STEPS = 100
LABEL_MAP = {"rock": 0, "scissors": 1, "paper": 2, "ok": 3}
MARGIN = 1000
Unpacking data into arrays for batching
In [8]:
data_dict = {
    "time": gs.array(data.time),
    "raw_data": gs.array(data[["c{}".format(i) for i in range(N_ELECTRODES)]]),
    "label": gs.array(data.label),
    "exp": gs.array(data.exp),
}
In [9]:
from geomstats.datasets.prepare_emg_data import TimeSeriesCovariance
cov_data = TimeSeriesCovariance(data_dict, N_STEPS, N_ELECTRODES, LABEL_MAP, MARGIN)
cov_data.transform()
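How TimeSeriesCovariance builds its windows (and how MARGIN is used) is not described in this notebook. As a hedged illustration of the general idea only, the sketch below splits a multichannel signal into non-overlapping windows of n_steps samples, drops windows that straddle a change of sign, and computes one electrode-by-electrode sample covariance per window. The helper window_covariances and the synthetic signal are hypothetical stand-ins, not the geomstats implementation.

```python
import numpy as np


def window_covariances(signal, labels, n_steps):
    """Hypothetical helper: one covariance matrix per non-overlapping window
    of n_steps samples whose samples all share the same label."""
    covs, window_labels = [], []
    for start in range(0, len(signal) - n_steps + 1, n_steps):
        window_lab = labels[start:start + n_steps]
        if np.any(window_lab != window_lab[0]):
            continue  # skip windows straddling a change of sign
        window = signal[start:start + n_steps]
        # np.cov expects variables in rows, so transpose to (n_electrodes, n_steps).
        covs.append(np.cov(window.T))
        window_labels.append(window_lab[0])
    return np.array(covs), np.array(window_labels)


# Synthetic stand-in for the EMG data: 8 electrodes, 1000 samples, two signs.
rng = np.random.default_rng(0)
signal = rng.normal(size=(1000, 8))
labels = np.array(["rock"] * 550 + ["paper"] * 450)

covs, cov_labels = window_covariances(signal, labels, n_steps=100)
print(covs.shape)  # (9, 8, 8): the window mixing rock and paper is dropped
```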
We check that these matrices belong to the space of SPD matrices.
In [10]:
import geomstats.geometry.spd_matrices as spd
manifold = spd.SPDMatrices(N_ELECTRODES)
In [11]:
gs.all(manifold.belongs(cov_data.covs))
Out [11]:
True
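A matrix is SPD when it is symmetric and all its eigenvalues are strictly positive. The sketch below performs this check with plain NumPy to illustrate the definition; it is not geomstats' own belongs implementation, and the tolerance atol is an arbitrary choice. It also shows why each window needs more samples than electrodes: a sample covariance built from T samples has rank at most T - 1, so it cannot be SPD when T is too small.

```python
import numpy as np


def is_spd(mat, atol=1e-8):
    """Return True if mat is square, symmetric and has only positive eigenvalues."""
    mat = np.asarray(mat, dtype=float)
    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        return False
    if not np.allclose(mat, mat.T, atol=atol):
        return False
    # eigvalsh assumes a symmetric matrix and returns real eigenvalues.
    return bool(np.all(np.linalg.eigvalsh(mat) > atol))


rng = np.random.default_rng(0)
cov = np.cov(rng.normal(size=(100, 8)).T)  # 100 samples in 8 dimensions
few_samples = rng.normal(size=(5, 8))       # only 5 samples in 8 dimensions

print(is_spd(cov))                                   # True
print(is_spd(np.diag([1.0, -1.0])))                  # False: negative eigenvalue
print(is_spd(np.array([[1.0, 2.0], [0.0, 1.0]])))    # False: not symmetric
print(is_spd(np.cov(few_samples.T)))                 # False: rank at most 4
```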
Covariance plots of the Euclidean average
In [12]:
fig, ax = plt.subplots(2, 2, figsize=(20, 10))
for label_i, i in cov_data.label_map.items():
    label_ids = np.where(cov_data.labels == i)[0]
    sign_cov_mat = cov_data.covs[label_ids]
    # Euclidean (entrywise) average of the covariance matrices of this sign.
    mean_cov = np.mean(sign_cov_mat, axis=0)
    ax[i // 2, i % 2].matshow(mean_cov)
    ax[i // 2, i % 2].title.set_text(label_i)

The Euclidean average of the SPD matrices for each sign does not show a striking difference between three of our signs (scissors, paper, and ok). The Minimum Distance to Mean (MDM) algorithm will therefore probably perform poorly if it uses the Euclidean mean here.
Covariance plots of the Fréchet mean for the affine-invariant metric
In [13]:
from geomstats.learning.frechet_mean import FrechetMean
from geomstats.geometry.spd_matrices import SPDAffineMetric
In [14]:
metric_affine = SPDAffineMetric(N_ELECTRODES)
mean_affine = FrechetMean(metric=metric_affine)
In [15]:
fig, ax = plt.subplots(2, 2, figsize=(20, 10))
for label_i, i in cov_data.label_map.items():
    label_ids = np.where(cov_data.labels == i)[0]
    sign_cov_mat = cov_data.covs[label_ids]
    # Fréchet mean of the covariance matrices of this sign (affine-invariant metric).
    mean_affine.fit(X=sign_cov_mat)
    mean_cov = mean_affine.estimate_
    ax[i // 2, i % 2].matshow(mean_cov)
    ax[i // 2, i % 2].title.set_text(label_i)

We see that the average matrices computed with the affine-invariant metric are now more differentiated from each other, which could give better results when using MDM to predict the sign linked to a matrix sample.
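For the affine-invariant metric, the Fréchet (Karcher) mean generally has no closed form and is computed iteratively. The sketch below is a standard fixed-point iteration in plain NumPy: it averages the samples' matrix logarithms seen from the current estimate \(M\), then maps the result back with \(M^{1/2} \exp(\cdot) M^{1/2}\). It is an illustration, not the FrechetMean implementation from geomstats, and the helper names, tolerance and iteration cap are hypothetical choices. For commuting (here diagonal) matrices, the result is the entrywise geometric mean, which makes the contrast with the Euclidean mean easy to see.

```python
import numpy as np


def _eig_fun(mat, fun):
    """Apply a scalar function to a symmetric matrix through its eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    return (eigvecs * fun(eigvals)) @ eigvecs.T


def affine_invariant_mean(mats, n_iter=100, tol=1e-10):
    """Fixed-point (Karcher) iteration for the affine-invariant Fréchet mean."""
    mean = np.mean(mats, axis=0)  # start from the Euclidean mean
    for _ in range(n_iter):
        sqrt_m = _eig_fun(mean, np.sqrt)
        inv_sqrt_m = _eig_fun(mean, lambda v: 1.0 / np.sqrt(v))
        # Average of the samples' logarithms, whitened by the current estimate.
        logs = [_eig_fun(inv_sqrt_m @ x @ inv_sqrt_m, np.log) for x in mats]
        step = np.mean(logs, axis=0)
        # Map the averaged tangent vector back onto the SPD manifold.
        mean = sqrt_m @ _eig_fun(step, np.exp) @ sqrt_m
        if np.linalg.norm(step) < tol:
            break
    return mean


mats = np.array([np.diag([1.0, 4.0]), np.diag([4.0, 1.0])])
print(np.mean(mats, axis=0))        # Euclidean mean: diag(2.5, 2.5)
print(affine_invariant_mean(mats))  # affine-invariant mean: diag(2, 2)
```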
Sign Classification
We are now going to train some classifiers on those matrices to see how accurately we can discriminate these 4 hand positions. The baseline accuracy is defined as the accuracy we get by randomly guessing the signs. In our case, the baseline accuracy is 25%.
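Uniform random guessing among 4 signs has an expected accuracy of 1/4 whatever the class balance, whereas always predicting the most frequent sign scores the proportion of the largest class, which equals 25% only when the classes are balanced. A minimal sketch with scikit-learn's DummyClassifier makes both baselines explicit; the label vectors are hypothetical, and chance_accuracies is an illustrative helper.

```python
import numpy as np
from sklearn.dummy import DummyClassifier

N_SIGNS = 4


def chance_accuracies(y, random_state=0):
    """Return (uniform-guess accuracy, majority-class accuracy) for labels y."""
    X = np.zeros((len(y), 1))  # features are ignored by DummyClassifier
    uniform = DummyClassifier(strategy="uniform", random_state=random_state).fit(X, y)
    majority = DummyClassifier(strategy="most_frequent").fit(X, y)
    return uniform.score(X, y), majority.score(X, y)


# Hypothetical label vectors: one balanced, one imbalanced.
y_balanced = np.repeat(np.arange(N_SIGNS), 5000)
y_imbalanced = np.repeat(np.arange(N_SIGNS), [10000, 5000, 2500, 2500])

print(chance_accuracies(y_balanced))    # (~0.25, 0.25)
print(chance_accuracies(y_imbalanced))  # (~0.25, 0.5)
```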
In [16]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate
from sklearn.preprocessing import StandardScaler
In [17]:
# Hiding the numerous sklearn warnings
import warnings
warnings.filterwarnings("ignore")
In [18]:
!pip install tensorflow
In [19]:
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
import tensorflow as tf
N_EPOCHS is the number of epochs used to train the MLP. About 100 epochs is recommended; here it is set to 10.
In [20]:
N_EPOCHS = 10
N_FEATURES = N_ELECTRODES * (N_ELECTRODES + 1) // 2
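A symmetric \(n \times n\) matrix is determined by its \(n(n+1)/2\) upper-triangular entries (diagonal included), which for \(n = 8\) electrodes gives 36 features; this is the input size the MLP expects when fed cov_data.covecs. The sketch below shows one such vectorization with np.triu_indices on synthetic covariances. The helper vectorize_upper is hypothetical, and the entry ordering used by TimeSeriesCovariance is an assumption.

```python
import numpy as np

N_ELECTRODES = 8
# n(n + 1) / 2 upper-triangular entries, computed with integer arithmetic.
N_FEATURES = N_ELECTRODES * (N_ELECTRODES + 1) // 2


def vectorize_upper(mats):
    """Flatten the upper triangle (diagonal included) of a stack of symmetric matrices."""
    rows, cols = np.triu_indices(mats.shape[-1])
    return mats[..., rows, cols]


rng = np.random.default_rng(0)
batches = rng.normal(size=(3, 100, N_ELECTRODES))
# One covariance matrix per batch of 100 synthetic samples.
covs = np.array([np.cov(batch.T) for batch in batches])
vecs = vectorize_upper(covs)
print(N_FEATURES, vecs.shape)  # 36 (3, 36)
```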
A. Test on the same session and user as Training/Calibration
In this first part, we train our model on the same session that we test it on. In real life, this corresponds to a user calibrating their armband right before using it. To do this, we split every session into \(k\) folds, train on \(k-1\) folds, and test on the remaining \(k^{th}\) fold.
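The splitting scheme can be made explicit: restrict the data to one session, then run a stratified k-fold split inside it, so that no test window ever comes from another session. In recent scikit-learn releases (0.22 and later), cross_validate with the default cv uses 5 unshuffled stratified folds for estimators it recognizes as classifiers, so the sketch below should match what is done per session, under that assumption. The helper intra_session_folds and the label arrays are hypothetical.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def intra_session_folds(sessions, labels, n_splits=5):
    """Yield (session, train_idx, test_idx) with k-fold splits computed inside each session."""
    for session in np.unique(sessions):
        ids = np.where(sessions == session)[0]
        skf = StratifiedKFold(n_splits=n_splits)
        # Only the labels matter for the split, so X is a placeholder.
        for train, test in skf.split(np.zeros(len(ids)), labels[ids]):
            # Map fold positions back to global sample indices.
            yield session, ids[train], ids[test]


# Hypothetical data: two sessions, 4 signs, 10 windows per sign and session.
sessions = np.repeat(["mg_s1", "rr_s1"], 40)
labels = np.tile(np.repeat(np.arange(4), 10), 2)

folds = list(intra_session_folds(sessions, labels))
print(len(folds))  # 10: 5 folds for each of the 2 sessions
```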
In [21]:
class ExpResults:
    """Class handling the score collection and plotting among the different experiments."""

    def __init__(self, exps):
        self.exps = exps
        self.results = {}
        self.exp_ids = {}
        # Compute the indices corresponding to each session only once at initialization.
        for exp in set(self.exps):
            self.exp_ids[exp] = np.where(self.exps == exp)[0]

    def add_result(self, model_name, model, X, y):
        """Add the results from the cross-validated pipeline.

        For the given model, add the cross-validated results of every session
        to the model_name entry of self.results.

        Parameters
        ----------
        model_name : str
            Name of the pipeline/model that we are adding results from.
        model : sklearn.pipeline.Pipeline
            sklearn pipeline that we are evaluating.
        X : array
            Data that we are ingesting in the pipeline.
        y : array
            Labels corresponding to the data.
        """
        self.results[model_name] = {
            "fit_time": [],
            "score_time": [],
            "test_score": [],
            "train_score": [],
        }
        for exp, ids in self.exp_ids.items():
            # Cross-validate within a single session, using the model passed as argument.
            exp_result = cross_validate(
                model, X[ids], y[ids], return_train_score=True
            )
            for key in exp_result.keys():
                self.results[model_name][key] += list(exp_result[key])
        print(
            "Average training score: {:.4f}, Average test score: {:.4f}".format(
                np.mean(self.results[model_name]["train_score"]),
                np.mean(self.results[model_name]["test_score"]),
            )
        )

    def plot_results(
        self,
        title,
        variables,
        err_bar=None,
        save_name=None,
        xlabel="Model",
        ylabel="Acc",
    ):
        """Plot bar plot comparing the different pipelines' results.

        Compare the results added previously using the 'add_result' method with bar plots.

        Parameters
        ----------
        title : str
            Title of the plot.
        variables : list of str
            List of the variables to plot (e.g. train_score, test_score, ...).
        err_bar : list of float
            List of errors to use for plotting error bars. If None, std is used by default.
        save_name : str
            Path to save the plot. If None, the plot is not saved.
        xlabel : str
            Label of the x-axis.
        ylabel : str
            Label of the y-axis.
        """
        ### Some default parameters.
        w = 0.5
        colors = ["b", "r", "gray"]

        ### Reshaping the results for plotting: one (n_scores, n_models) array per variable.
        x_labels = list(self.results.keys())
        list_vec = []
        for variable in variables:
            list_vec.append(
                np.array(
                    [self.results[model][variable] for model in x_labels]
                ).transpose()
            )
        # Uniform random values in [-1, 1), used to jitter the scatter points.
        rand_m1 = lambda size: np.random.random(size) * 2 - 1

        ### Plots parameters.
        label_loc = np.arange(len(x_labels))
        center_bar = [w * (i - 0.5) for i in range(len(list_vec))]

        ### Plots values.
        avg_vec = [np.nanmean(vec, axis=0) for vec in list_vec]
        if err_bar is None:
            err_bar = [np.nanstd(vec, axis=0) for vec in list_vec]

        ### Plotting the data.
        fig, ax = plt.subplots(figsize=(20, 15))
        for i, vec in enumerate(list_vec):
            label_i = variables[i] + " (n = {})".format(len(vec))
            ax.bar(
                label_loc + center_bar[i],
                avg_vec[i],
                w,
                label=label_i,
                yerr=err_bar[i],
                color=colors[i],
                alpha=0.6,
            )
            for j, x in enumerate(label_loc):
                ax.scatter(
                    (x + center_bar[i]) + rand_m1(vec[:, j].size) * w / 4,
                    vec[:, j],
                    color=colors[i],
                    edgecolor="k",
                )
        # Add some text for labels, title and custom x-axis tick labels, etc.
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(label_loc)
        ax.set_xticklabels(x_labels)
        ax.legend()

        ### Saving the figure if a path is given.
        if save_name is not None:
            plt.savefig(save_name)
In [22]:
exp_arr = data.exp.iloc[cov_data.batches]
intra_sessions_results = ExpResults(exp_arr)
A.0. Using Logistic Regression on the vectorized matrix (Euclidean method)
In [23]:
pipeline = Pipeline(
    steps=[
        ("standardize", StandardScaler()),
        ("logreg", LogisticRegression(solver="lbfgs", multi_class="multinomial")),
    ]
)
intra_sessions_results.add_result(
    model_name="logreg_eucl", model=pipeline, X=cov_data.covecs, y=cov_data.labels
)
Average training score: 0.9937, Average test score: 0.9165
A.1. Using MLP on the vectorized matrix (Euclidean method)
In [24]:
def create_model(weights="initial_weights.hd5", n_features=N_FEATURES, n_signs=N_SIGNS):
    """Create the Keras model, as required by KerasClassifier to wrap a Keras model
    in a scikit-learn interface.

    Weights are saved/loaded so that every model starts from the same initialization,
    which removes the randomness of the weight initialization (for better comparison).
    """
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Dense(
                n_features, activation="relu", input_shape=(n_features,)
            ),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(17, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(n_signs, activation="softmax"),
        ]
    )
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="rmsprop",
        metrics=["accuracy"],
    )
    if weights is None:
        # Save the freshly initialized weights so that later models can reuse them.
        model.save_weights("initial_weights.hd5")
    else:
        # Start from the previously saved initialization.
        model.load_weights(weights)
    return model


def create_model_covariance(weights="initial_weights.hd5"):
    return create_model(weights=weights, n_features=N_FEATURES)
Run the line below to generate the ‘initial_weights.hd5’ file.
In [25]:
generate_weights = create_model(weights=None)
In [26]:
pipeline = Pipeline(
    steps=[
        ("standardize", StandardScaler()),
        ("mlp", KerasClassifier(build_fn=create_model, epochs=N_EPOCHS, verbose=0)),
    ]
)
intra_sessions_results.add_result(
    model_name="mlp_eucl", model=pipeline, X=cov_data.covecs, y=cov_data.labels
)
Average training score: 0.9445, Average test score: 0.8083
A.2. Using tangent space projection + Logistic Regression
In [27]:
from geomstats.learning.preprocessing import ToTangentSpace
pipeline = Pipeline(
    steps=[
        ("feature_ext", ToTangentSpace(geometry=metric_affine)),
        ("standardize", StandardScaler()),
        ("logreg", LogisticRegression(solver="lbfgs", multi_class="multinomial")),
    ]
)
intra_sessions_results.add_result(
    model_name="logreg_affinvariant_tangent",
    model=pipeline,
    X=cov_data.covs,
    y=cov_data.labels,
)
Average training score: 0.9959, Average test score: 0.9200
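Projecting onto a tangent space turns SPD matrices into vectors of a flat space, where Euclidean classifiers such as logistic regression apply. A common recipe, sketched below under the affine-invariant metric, picks a reference point \(M\) (typically the Fréchet mean of the training matrices), maps each matrix \(X\) to \(\log(M^{-1/2} X M^{-1/2})\), and keeps the upper triangle, weighting off-diagonal entries by \(\sqrt{2}\) so that the vector's Euclidean norm equals the affine-invariant distance between \(M\) and \(X\). This is a hedged illustration of the technique: ToTangentSpace in geomstats may differ in its reference point, scaling and vectorization, and the helper names and the choice of reference below are hypothetical.

```python
import numpy as np


def _eig_fun(mat, fun):
    """Apply a scalar function to a symmetric matrix through its eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    return (eigvecs * fun(eigvals)) @ eigvecs.T


def tangent_vectors(mats, ref):
    """Map SPD matrices to vectors of the tangent space at ref (whitened log map)."""
    inv_sqrt_ref = _eig_fun(ref, lambda v: 1.0 / np.sqrt(v))
    rows, cols = np.triu_indices(ref.shape[0])
    # sqrt(2) on off-diagonal entries preserves the Frobenius norm.
    weights = np.where(rows == cols, 1.0, np.sqrt(2.0))
    vecs = []
    for x in mats:
        log_x = _eig_fun(inv_sqrt_ref @ x @ inv_sqrt_ref, np.log)
        vecs.append(weights * log_x[rows, cols])
    return np.array(vecs)


rng = np.random.default_rng(0)
covs = np.array([np.cov(rng.normal(size=(100, 8)).T) for _ in range(5)])
ref = covs[0]  # hypothetical reference point; a Fréchet mean is more usual
vecs = tangent_vectors(covs, ref)
print(vecs.shape)                 # (5, 36)
print(np.allclose(vecs[0], 0.0))  # True: the reference point maps to the origin
```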
A.3. Using tangent space projection + MLP
In [28]:
pipeline = Pipeline(
    steps=[
        ("feature_ext", ToTangentSpace(geometry=metric_affine)),
        ("standardize", StandardScaler()),
        (
            "mlp",
            KerasClassifier(
                build_fn=create_model_covariance, epochs=N_EPOCHS, verbose=0
            ),
        ),
    ]
)
intra_sessions_results.add_result(
    model_name="mlp_affinvariant_tangent",
    model=pipeline,
    X=cov_data.covs,
    y=cov_data.labels,
)
Average training score: 0.9601, Average test score: 0.8358
A.4. Using Euclidean MDM
In [29]:
from geomstats.learning.mdm import RiemannianMinimumDistanceToMean
from geomstats.geometry.spd_matrices import SPDEuclideanMetric
pipeline = Pipeline(
    steps=[
        (
            "clf",
            RiemannianMinimumDistanceToMean(
                riemannian_metric=SPDEuclideanMetric(n=N_ELECTRODES)
            ),
        )
    ]
)
intra_sessions_results.add_result(
    model_name="mdm_eucl", model=pipeline, X=cov_data.covs, y=cov_data.labels
)
Average training score: 0.8552, Average test score: 0.7710
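Minimum Distance to Mean assigns a matrix to the class whose mean is closest. With the Euclidean metric, the class means are plain entrywise averages and the distance is the Frobenius norm of the difference. The sketch below is a minimal NumPy version with a scikit-learn-like fit/predict/score interface, written for illustration only; the class name EuclideanMDM and the toy data are hypothetical, and it is not geomstats' RiemannianMinimumDistanceToMean.

```python
import numpy as np


class EuclideanMDM:
    """Hypothetical minimal MDM classifier with the Euclidean (Frobenius) metric."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        # One arithmetic mean matrix per class.
        self.means_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        # dists[i, k] = Frobenius distance between sample i and the mean of class k.
        diffs = X[:, None, :, :] - self.means_[None, :, :, :]
        dists = np.linalg.norm(diffs, axis=(-2, -1))
        return self.classes_[np.argmin(dists, axis=1)]

    def score(self, X, y):
        return float(np.mean(self.predict(X) == y))


# Hypothetical toy data: two classes of 2x2 SPD matrices around different centers.
rng = np.random.default_rng(0)
centers = [np.diag([1.0, 3.0]), np.diag([3.0, 1.0])]
X = np.array([c + 0.1 * rng.random() * np.eye(2) for c in centers for _ in range(10)])
y = np.repeat([0, 1], 10)

clf = EuclideanMDM().fit(X, y)
print(clf.score(X, y))  # 1.0 on this well-separated toy data
```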
A.5. Using Riemannian MDM
In [30]:
pipeline = Pipeline(
    steps=[
        (
            "clf",
            RiemannianMinimumDistanceToMean(
                riemannian_metric=SPDAffineMetric(n=N_ELECTRODES)
            ),
        )
    ]
)
intra_sessions_results.add_result(
    model_name="mdm_affinvariant", model=pipeline, X=cov_data.covs, y=cov_data.labels
)
Average training score: 0.9342, Average test score: 0.8353
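The Riemannian variant keeps the same decision rule but swaps the distance (and the class means, which become Fréchet means as computed earlier with FrechetMean). Under the affine-invariant metric, \(d(A, B) = \lVert \log(A^{-1/2} B A^{-1/2}) \rVert_F\), which depends only on the eigenvalues of \(A^{-1} B\). The sketch below, in plain NumPy with hypothetical helper names, computes this distance, checks its invariance under a congruence \(A \mapsto G A G^\top\) with an invertible \(G\) (the property that names the metric), and applies the MDM rule given precomputed class means.

```python
import numpy as np


def affine_invariant_distance(a, b):
    """d(A, B) = ||log(A^{-1/2} B A^{-1/2})||_F, computed from the eigenvalues of A^{-1} B."""
    # A^{-1} B is similar to A^{-1/2} B A^{-1/2}, so their positive eigenvalues coincide.
    eigvals = np.linalg.eigvals(np.linalg.solve(a, b)).real
    return float(np.sqrt(np.sum(np.log(eigvals) ** 2)))


def mdm_predict(X, class_means, distance):
    """Assign each matrix to the index of the closest class mean under `distance`."""
    dists = np.array([[distance(mean, x) for mean in class_means] for x in X])
    return np.argmin(dists, axis=1)


rng = np.random.default_rng(0)
a = np.cov(rng.normal(size=(50, 3)).T)
b = np.cov(rng.normal(size=(50, 3)).T)
g = rng.normal(size=(3, 3))  # almost surely invertible

d_ab = affine_invariant_distance(a, b)
d_gab = affine_invariant_distance(g @ a @ g.T, g @ b @ g.T)
print(np.isclose(d_ab, d_gab))  # True: invariance under congruence
print(mdm_predict(np.array([a, b]), [a, b], affine_invariant_distance))  # [0 1]
```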
Summary plots
In [31]:
intra_sessions_results.plot_results("intra_sess", ["test_score"])
