You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_MatrixClassifier.hpp 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. ///-------------------------------------------------------------------------------------------------
  2. ///
  3. /// \file test_MatrixClassifier.hpp
  4. /// \brief Tests for Matrix Classifiers.
  5. /// \author Thibaut Monseigne (Inria).
  6. /// \version 1.0.
  7. /// \date 09/01/2019.
  8. /// \copyright <a href="https://choosealicense.com/licenses/agpl-3.0/">GNU Affero General Public License v3.0</a>.
  9. /// \remarks
  /// - For these tests I compare the results with the <a href="https://github.com/alexandrebarachant/pyRiemann">pyRiemann</a> library (<a href="https://github.com/alexandrebarachant/pyRiemann/blob/master/LICENSE">License</a>), or with <a href="http://scikit-learn.org">sklearn</a> when pyRiemann simply redirects to it.
  /// - For the adaptation classification tests I compare the results with the <a href="https://github.com/alexandrebarachant/covariancetoolbox">covariancetoolbox</a> Matlab library (<a href="https://github.com/alexandrebarachant/covariancetoolbox/blob/master/COPYING">License</a>).
  /// - The Matlab toolbox is older and its Riemannian mean estimation is different, so the tests are adapted to switch between the two libraries.
  13. ///
  14. ///-------------------------------------------------------------------------------------------------
  15. #pragma once
  16. #include "gtest/gtest.h"
  17. #include "misc.hpp"
  18. #include "init.hpp"
  19. #include <geometry/classifier/CMatrixClassifierMDM.hpp>
  20. #include <geometry/classifier/CMatrixClassifierMDMRebias.hpp>
  21. #include <geometry/classifier/CMatrixClassifierFgMDM.hpp>
  22. #include <geometry/classifier/CMatrixClassifierFgMDMRT.hpp>
  23. #include <geometry/classifier/CMatrixClassifierFgMDMRTRebias.hpp>
  /// Empty list of reference distances. When passed to TestClassify, no per-sample distance is checked.
  static const std::vector<std::vector<double>> EMPTY_DIST{};
  25. //---------------------------------------------------------------------------------------------------
  26. static void TestClassify(Geometry::IMatrixClassifier& calc, const std::vector<std::vector<Eigen::MatrixXd>>& dataset, const std::vector<size_t>& prediction,
  27. const std::vector<std::vector<double>>& predictionDistance, const Geometry::EAdaptations& adapt)
  28. {
  29. Eigen::MatrixXd result = Eigen::MatrixXd::Zero(NB_CLASS, NB_CLASS);
  30. size_t idx = 0;
  31. for (size_t k = 0; k < dataset.size(); ++k)
  32. {
  33. for (size_t i = 0; i < dataset[k].size(); ++i)
  34. {
  35. const std::string text = "sample [" + std::to_string(k) + "][" + std::to_string(i) + "]";
  36. size_t classid = 0;
  37. std::vector<double> distance, probability;
  38. EXPECT_TRUE(calc.classify(dataset[k][i], classid, distance, probability, adapt, k)) << "Error during Classify " << text;
  39. if (idx < prediction.size()) { EXPECT_TRUE(prediction[idx] == classid) << ErrorMsg("Prediction " + text, prediction[idx], classid); }
  40. if (idx < predictionDistance.size())
  41. {
  42. EXPECT_TRUE(isAlmostEqual(predictionDistance[idx], distance)) << ErrorMsg("Prediction Distance " + text, predictionDistance[idx], distance);
  43. }
  44. idx++;
  45. result(k, classid)++;
  46. }
  47. }
  48. std::cout << "***** Classifier : *****" << std::endl << calc << std::endl << "***** Result : *****" << std::endl << result << std::endl;
  49. }
  50. //---------------------------------------------------------------------------------------------------
  51. //---------------------------------------------------------------------------------------------------
  52. class Tests_MatrixClassifier : public testing::Test
  53. {
  54. protected:
  55. std::vector<std::vector<Eigen::MatrixXd>> m_dataSet;
  56. void SetUp() override { m_dataSet = InitCovariance::LWF::Reference(); }
  57. };
  58. //---------------------------------------------------------------------------------------------------
  59. //---------------------------------------------------------------------------------------------------
  60. TEST_F(Tests_MatrixClassifier, MDM_Train)
  61. {
  62. const Geometry::CMatrixClassifierMDM ref = InitMatrixClassif::MDM::Reference();
  63. Geometry::CMatrixClassifierMDM calc;
  64. EXPECT_TRUE(calc.train(m_dataSet)) << "Error during Training : " << std::endl << calc << std::endl;
  65. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Train", ref, calc);
  66. }
  67. //---------------------------------------------------------------------------------------------------
  68. //---------------------------------------------------------------------------------------------------
  69. TEST_F(Tests_MatrixClassifier, MDM_Classifify)
  70. {
  71. Geometry::CMatrixClassifierMDM calc = InitMatrixClassif::MDM::ReferenceMatlab();
  72. TestClassify(calc, m_dataSet, InitMatrixClassif::MDM::Prediction(), InitMatrixClassif::MDM::PredictionDistance(), Geometry::EAdaptations::None);
  73. const Geometry::CMatrixClassifierMDM ref = InitMatrixClassif::MDM::ReferenceMatlab(); // No Change
  74. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Classify Change without adaptation mode", ref, calc);
  75. }
  76. //---------------------------------------------------------------------------------------------------
  77. //---------------------------------------------------------------------------------------------------
  78. TEST_F(Tests_MatrixClassifier, MDM_Classifify_Adapt_Supervised)
  79. {
  80. Geometry::CMatrixClassifierMDM calc = InitMatrixClassif::MDM::ReferenceMatlab();
  81. TestClassify(calc, m_dataSet, InitMatrixClassif::MDM::PredictionSupervised(), InitMatrixClassif::MDM::PredictionDistanceSupervised(),
  82. Geometry::EAdaptations::Supervised);
  83. const Geometry::CMatrixClassifierMDM ref = InitMatrixClassif::MDM::AfterSupervised();
  84. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Adapt Classify after Supervised adaptation", ref, calc);
  85. }
  86. //---------------------------------------------------------------------------------------------------
  87. //---------------------------------------------------------------------------------------------------
  88. TEST_F(Tests_MatrixClassifier, MDM_Classifify_Adapt_Unsupervised)
  89. {
  90. Geometry::CMatrixClassifierMDM calc = InitMatrixClassif::MDM::ReferenceMatlab();
  91. TestClassify(calc, m_dataSet, InitMatrixClassif::MDM::PredictionUnSupervised(), InitMatrixClassif::MDM::PredictionDistanceUnSupervised(),
  92. Geometry::EAdaptations::Unsupervised);
  93. const Geometry::CMatrixClassifierMDM ref = InitMatrixClassif::MDM::AfterUnSupervised();
  94. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Adapt Classify after Unsupervised adaptation", ref, calc);
  95. }
  96. //---------------------------------------------------------------------------------------------------
  97. //---------------------------------------------------------------------------------------------------
  98. TEST_F(Tests_MatrixClassifier, MDM_Save)
  99. {
  100. Geometry::CMatrixClassifierMDM calc;
  101. const Geometry::CMatrixClassifierMDM ref = InitMatrixClassif::MDM::Reference();
  102. EXPECT_TRUE(ref.saveXML("test_MDM_Save.xml")) << "Error during Saving : " << std::endl << ref << std::endl;
  103. EXPECT_TRUE(calc.loadXML("test_MDM_Save.xml")) << "Error during Loading : " << std::endl << calc << std::endl;
  104. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Save", ref, calc);
  105. }
  106. //---------------------------------------------------------------------------------------------------
  107. //---------------------------------------------------------------------------------------------------
  108. TEST_F(Tests_MatrixClassifier, FgMDMRT_Train)
  109. {
  110. const Geometry::CMatrixClassifierFgMDMRT ref = InitMatrixClassif::FgMDMRT::Reference();
  111. Geometry::CMatrixClassifierFgMDMRT calc;
  112. EXPECT_TRUE(calc.train(m_dataSet)) << "Error during Training : " << std::endl << calc << std::endl;
  113. EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Train", ref, calc);
  114. }
  115. //---------------------------------------------------------------------------------------------------
  116. //---------------------------------------------------------------------------------------------------
  117. TEST_F(Tests_MatrixClassifier, FgMDMRT_Classifify)
  118. {
  119. Geometry::CMatrixClassifierFgMDMRT calc = InitMatrixClassif::FgMDMRT::Reference();
  120. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDMRT::Prediction(), InitMatrixClassif::FgMDMRT::PredictionDistance(), Geometry::EAdaptations::None);
  121. const Geometry::CMatrixClassifierFgMDMRT ref = InitMatrixClassif::FgMDMRT::Reference();
  122. EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Classify Change without adaptation mode", ref, calc);
  123. }
  124. //---------------------------------------------------------------------------------------------------
  125. //---------------------------------------------------------------------------------------------------
  126. TEST_F(Tests_MatrixClassifier, FgMDMRT_Classifify_Adapt_Supervised)
  127. {
  128. Geometry::CMatrixClassifierFgMDMRT calc = InitMatrixClassif::FgMDMRT::Reference();
  129. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDMRT::PredictionSupervised(), EMPTY_DIST, Geometry::EAdaptations::Supervised);
  130. //const Geometry::CMatrixClassifierFgMDMRT ref = InitMatrixClassif::FgMDMRT::AfterSupervised();
  131. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Adapt Classify after Supervised RT adaptation", ref, calc);
  132. }
  133. //---------------------------------------------------------------------------------------------------
  134. //---------------------------------------------------------------------------------------------------
  135. TEST_F(Tests_MatrixClassifier, FgMDMRT_Classifify_Adapt_Unsupervised)
  136. {
  137. Geometry::CMatrixClassifierFgMDMRT calc(InitMatrixClassif::FgMDMRT::Reference());
  138. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDMRT::PredictionUnSupervised(), EMPTY_DIST, Geometry::EAdaptations::Unsupervised);
  139. //const Geometry::CMatrixClassifierFgMDMRT ref = InitMatrixClassif::FgMDMRT::AfterUnSupervised();
  140. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Adapt Classify after Unsupervised RT adaptation", ref, calc);
  141. }
  142. //---------------------------------------------------------------------------------------------------
  143. //---------------------------------------------------------------------------------------------------
  144. TEST_F(Tests_MatrixClassifier, FgMDMRT_Save)
  145. {
  146. Geometry::CMatrixClassifierFgMDMRT calc;
  147. const Geometry::CMatrixClassifierFgMDMRT ref = InitMatrixClassif::FgMDMRT::Reference();
  148. EXPECT_TRUE(ref.saveXML("test_FgMDM_Save.xml")) << "Error during Saving : " << std::endl << ref << std::endl;
  149. EXPECT_TRUE(calc.loadXML("test_FgMDM_Save.xml")) << "Error during Loading : " << std::endl << calc << std::endl;
  150. EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Save", ref, calc);
  151. }
  152. //---------------------------------------------------------------------------------------------------
  153. //---------------------------------------------------------------------------------------------------
  154. TEST_F(Tests_MatrixClassifier, FgMDM_Classifify_Adapt_Supervised)
  155. {
  156. Geometry::CMatrixClassifierFgMDM calc = InitMatrixClassif::FgMDM::Reference();
  157. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDM::PredictionSupervised(), EMPTY_DIST, Geometry::EAdaptations::Supervised);
  158. //const Geometry::CMatrixClassifierFgMDM ref = InitMatrixClassif::FgMDM::AfterSupervised();
  159. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Adapt Classify after Supervised adaptation", ref, calc);
  160. }
  161. //---------------------------------------------------------------------------------------------------
  162. //---------------------------------------------------------------------------------------------------
  163. TEST_F(Tests_MatrixClassifier, FgMDM_Classifify_Adapt_Unsupervised)
  164. {
  165. Geometry::CMatrixClassifierFgMDM calc = InitMatrixClassif::FgMDM::Reference();
  166. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDM::PredictionUnSupervised(), EMPTY_DIST, Geometry::EAdaptations::Unsupervised);
  167. //const Geometry::CMatrixClassifierFgMDM ref = InitMatrixClassif::FgMDM::AfterUnSupervised();
  168. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Adapt Classify after Unsupervised adaptation", ref, calc);
  169. }
  170. //---------------------------------------------------------------------------------------------------
  171. //---------------------------------------------------------------------------------------------------
  172. TEST_F(Tests_MatrixClassifier, MDM_Rebias_Train)
  173. {
  174. Geometry::CMatrixClassifierMDMRebias calc;
  175. EXPECT_TRUE(calc.train(m_dataSet)) << "Error during Training : " << std::endl << calc << std::endl;
  176. //const Geometry::CMatrixClassifierMDMRebias ref = InitMatrixClassif::MDMRebias::Reference();
  177. //EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Rebias Train", ref, calc); // The mean method is different in matlab toolbox and python toolbox
  178. }
  179. //---------------------------------------------------------------------------------------------------
  180. //---------------------------------------------------------------------------------------------------
  181. TEST_F(Tests_MatrixClassifier, MDM_Rebias_Classifify)
  182. {
  183. Geometry::CMatrixClassifierMDMRebias calc = InitMatrixClassif::MDMRebias::Reference();
  184. TestClassify(calc, m_dataSet, InitMatrixClassif::MDMRebias::Prediction(), InitMatrixClassif::MDMRebias::PredictionDistance(), Geometry::EAdaptations::None);
  185. const Geometry::CMatrixClassifierMDMRebias ref = InitMatrixClassif::MDMRebias::After(); // No Class change but Rebias yes
  186. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Rebias Classify Change without adaptation mode", ref, calc);
  187. }
  188. //---------------------------------------------------------------------------------------------------
  189. //---------------------------------------------------------------------------------------------------
  190. TEST_F(Tests_MatrixClassifier, MDM_Rebias_Classifify_Adapt_Supervised)
  191. {
  192. Geometry::CMatrixClassifierMDMRebias calc = InitMatrixClassif::MDMRebias::Reference();
  193. TestClassify(calc, m_dataSet, InitMatrixClassif::MDMRebias::PredictionSupervised(), InitMatrixClassif::MDMRebias::PredictionDistanceSupervised(),
  194. Geometry::EAdaptations::Supervised);
  195. const Geometry::CMatrixClassifierMDMRebias ref = InitMatrixClassif::MDMRebias::AfterSupervised();
  196. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Rebias Adapt Classify after Supervised adaptation", ref, calc);
  197. }
  198. //---------------------------------------------------------------------------------------------------
  199. //---------------------------------------------------------------------------------------------------
  200. TEST_F(Tests_MatrixClassifier, MDM_Rebias_Classifify_Adapt_Unsupervised)
  201. {
  202. Geometry::CMatrixClassifierMDMRebias calc = InitMatrixClassif::MDMRebias::Reference();
  203. TestClassify(calc, m_dataSet, InitMatrixClassif::MDMRebias::PredictionUnSupervised(), InitMatrixClassif::MDMRebias::PredictionDistanceUnSupervised(),
  204. Geometry::EAdaptations::Unsupervised);
  205. const Geometry::CMatrixClassifierMDMRebias ref = InitMatrixClassif::MDMRebias::AfterUnSupervised();
  206. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Rebias Adapt Classify after Unsupervised adaptation", ref, calc);
  207. }
  208. //---------------------------------------------------------------------------------------------------
  209. //---------------------------------------------------------------------------------------------------
  210. TEST_F(Tests_MatrixClassifier, MDM_Rebias_Save)
  211. {
  212. Geometry::CMatrixClassifierMDMRebias calc;
  213. const Geometry::CMatrixClassifierMDMRebias ref = InitMatrixClassif::MDMRebias::Reference();
  214. EXPECT_TRUE(ref.saveXML("test_MDM_Rebias_Save.xml")) << "Error during Saving : " << std::endl << ref << std::endl;
  215. EXPECT_TRUE(calc.loadXML("test_MDM_Rebias_Save.xml")) << "Error during Loading : " << std::endl << calc << std::endl;
  216. EXPECT_TRUE(ref == calc) << ErrorMsg("MDM Rebias Save", ref, calc);
  217. }
  218. //---------------------------------------------------------------------------------------------------
  219. //---------------------------------------------------------------------------------------------------
  220. TEST_F(Tests_MatrixClassifier, FgMDM_RT_Rebias_Train)
  221. {
  222. Geometry::CMatrixClassifierFgMDMRTRebias calc;
  223. EXPECT_TRUE(calc.train(m_dataSet)) << "Error during Training : " << std::endl << calc << std::endl;
  224. const Geometry::CMatrixClassifierFgMDMRTRebias ref = InitMatrixClassif::FgMDMRTRebias::Reference();
  225. EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Rebias Train", ref, calc); // The mean method is different in matlab toolbox and python toolbox
  226. }
  227. //---------------------------------------------------------------------------------------------------
  228. //---------------------------------------------------------------------------------------------------
  229. TEST_F(Tests_MatrixClassifier, FgMDM_RT_Rebias_Save)
  230. {
  231. Geometry::CMatrixClassifierFgMDMRTRebias calc;
  232. const Geometry::CMatrixClassifierFgMDMRTRebias ref = InitMatrixClassif::FgMDMRTRebias::Reference();
  233. EXPECT_TRUE(ref.saveXML("test_FgMDM_Rebias_Save.xml")) << "Error during Saving : " << std::endl << ref << std::endl;
  234. EXPECT_TRUE(calc.loadXML("test_FgMDM_Rebias_Save.xml")) << "Error during Loading : " << std::endl << calc << std::endl;
  235. EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Rebias Save", ref, calc);
  236. }
  237. //---------------------------------------------------------------------------------------------------
  238. //---------------------------------------------------------------------------------------------------
  239. TEST_F(Tests_MatrixClassifier, FgMDM_RT_Rebias_Classifify)
  240. {
  241. Geometry::CMatrixClassifierFgMDMRTRebias calc = InitMatrixClassif::FgMDMRTRebias::Reference();
  242. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDMRTRebias::Prediction(), EMPTY_DIST, Geometry::EAdaptations::None);
  243. //const Geometry::CMatrixClassifierFgMDMRTRebias ref = InitMatrixClassif::FgMDMRTRebias::After();
  244. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Rebias Classify Change without adaptation mode", ref, calc);
  245. }
  246. //---------------------------------------------------------------------------------------------------
  247. //---------------------------------------------------------------------------------------------------
  248. TEST_F(Tests_MatrixClassifier, FgMDM_RT_Rebias_Classifify_Adapt_Supervised)
  249. {
  250. Geometry::CMatrixClassifierFgMDMRTRebias calc = InitMatrixClassif::FgMDMRTRebias::Reference();
  251. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDMRTRebias::PredictionSupervised(), EMPTY_DIST, Geometry::EAdaptations::Supervised);
  252. //const Geometry::CMatrixClassifierFgMDMRTRebias ref = InitMatrixClassif::FgMDMRTRebias::AfterSupervised();
  253. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Rebias Adapt Classify after Supervised adaptation", ref, calc);
  254. }
  255. //---------------------------------------------------------------------------------------------------
  256. //---------------------------------------------------------------------------------------------------
  257. TEST_F(Tests_MatrixClassifier, FgMDM_RT_Rebias_Classifify_Adapt_Unsupervised)
  258. {
  259. Geometry::CMatrixClassifierFgMDMRTRebias calc = InitMatrixClassif::FgMDMRTRebias::Reference();
  260. TestClassify(calc, m_dataSet, InitMatrixClassif::FgMDMRTRebias::PredictionUnSupervised(), EMPTY_DIST, Geometry::EAdaptations::Unsupervised);
  261. //const Geometry::CMatrixClassifierFgMDMRTRebias ref = InitMatrixClassif::FgMDMRTRebias::AfterUnSupervised();
  262. //EXPECT_TRUE(ref == calc) << ErrorMsg("FgMDM Rebias Adapt Classify after Unsupervised adaptation", ref, calc);
  263. }
  264. //---------------------------------------------------------------------------------------------------