You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ovpCBoxAlgorithmXDAWNTrainer.cpp 8.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. #include "ovpCBoxAlgorithmXDAWNTrainer.h"
  2. #include "fs/Files.h"
  3. #include <cstdio>
  4. #include <iostream>
  5. namespace OpenViBE {
  6. namespace Plugins {
  7. namespace SignalProcessing {
  8. CBoxAlgorithmXDAWNTrainer::CBoxAlgorithmXDAWNTrainer() {}
  9. bool CBoxAlgorithmXDAWNTrainer::initialize()
  10. {
  11. m_trainStimulationID = FSettingValueAutoCast(*this->getBoxAlgorithmContext(), 0);
  12. m_filterFilename = FSettingValueAutoCast(*this->getBoxAlgorithmContext(), 1);
  13. OV_ERROR_UNLESS_KRF(m_filterFilename.length() != 0, "The filter filename is empty.\n", Kernel::ErrorType::BadSetting);
  14. if (FS::Files::fileExists(m_filterFilename))
  15. {
  16. FILE* file = FS::Files::open(m_filterFilename, "wt");
  17. OV_ERROR_UNLESS_KRF(file != nullptr, "The filter file exists but cannot be used.\n", Kernel::ErrorType::BadFileRead);
  18. fclose(file);
  19. }
  20. const int filterDimension = FSettingValueAutoCast(*this->getBoxAlgorithmContext(), 2);
  21. OV_ERROR_UNLESS_KRF(filterDimension > 0, "The dimension of the filter must be strictly positive.\n", Kernel::ErrorType::OutOfBound);
  22. m_filterDim = size_t(filterDimension);
  23. m_saveAsBoxConfig = FSettingValueAutoCast(*this->getBoxAlgorithmContext(), 3);
  24. m_stimDecoder.initialize(*this, 0);
  25. m_signalDecoder[0].initialize(*this, 1);
  26. m_signalDecoder[1].initialize(*this, 2);
  27. m_stimEncoder.initialize(*this, 0);
  28. return true;
  29. }
  30. bool CBoxAlgorithmXDAWNTrainer::uninitialize()
  31. {
  32. m_stimDecoder.uninitialize();
  33. m_signalDecoder[0].uninitialize();
  34. m_signalDecoder[1].uninitialize();
  35. m_stimEncoder.uninitialize();
  36. return true;
  37. }
  38. bool CBoxAlgorithmXDAWNTrainer::processInput(const size_t index)
  39. {
  40. if (index == 0) { this->getBoxAlgorithmContext()->markAlgorithmAsReadyToProcess(); }
  41. return true;
  42. }
  43. bool CBoxAlgorithmXDAWNTrainer::process()
  44. {
  45. Kernel::IBoxIO& dynamicBoxContext = this->getDynamicBoxContext();
  46. bool train = false;
  47. for (size_t i = 0; i < dynamicBoxContext.getInputChunkCount(0); ++i)
  48. {
  49. m_stimEncoder.getInputStimulationSet()->clear();
  50. m_stimDecoder.decode(i);
  51. if (m_stimDecoder.isHeaderReceived()) { m_stimEncoder.encodeHeader(); }
  52. if (m_stimDecoder.isBufferReceived())
  53. {
  54. for (size_t j = 0; j < m_stimDecoder.getOutputStimulationSet()->getStimulationCount(); ++j)
  55. {
  56. const uint64_t stimulationId = m_stimDecoder.getOutputStimulationSet()->getStimulationIdentifier(j);
  57. if (stimulationId == m_trainStimulationID)
  58. {
  59. train = true;
  60. m_stimEncoder.getInputStimulationSet()->appendStimulation(
  61. OVTK_StimulationId_TrainCompleted, m_stimDecoder.getOutputStimulationSet()->getStimulationDate(j), 0);
  62. }
  63. }
  64. m_stimEncoder.encodeBuffer();
  65. }
  66. if (m_stimDecoder.isEndReceived()) { m_stimEncoder.encodeEnd(); }
  67. dynamicBoxContext.markOutputAsReadyToSend(0, dynamicBoxContext.getInputChunkStartTime(0, i), dynamicBoxContext.getInputChunkEndTime(0, i));
  68. }
  69. if (train)
  70. {
  71. std::vector<size_t> erpSampleIndexes;
  72. std::array<Eigen::MatrixXd, 2> X; // X[0] is session matrix, X[1] is averaged ERP
  73. std::array<Eigen::MatrixXd, 2> C; // Covariance matrices
  74. std::array<size_t, 2> n;
  75. size_t nChannel = 0;
  76. this->getLogManager() << Kernel::LogLevel_Info << "Received train stimulation...\n";
  77. // Decodes input signals
  78. for (size_t j = 0; j < 2; ++j)
  79. {
  80. n[j] = 0;
  81. for (size_t i = 0; i < dynamicBoxContext.getInputChunkCount(j + 1); ++i)
  82. {
  83. Toolkit::TSignalDecoder<CBoxAlgorithmXDAWNTrainer>& decoder = m_signalDecoder[j];
  84. decoder.decode(i);
  85. CMatrix* matrix = decoder.getOutputMatrix();
  86. nChannel = matrix->getDimensionSize(0);
  87. const size_t nSample = matrix->getDimensionSize(1);
  88. const size_t sampling = size_t(decoder.getOutputSamplingRate());
  89. if (decoder.isHeaderReceived())
  90. {
  91. OV_ERROR_UNLESS_KRF(sampling > 0, "Input sampling frequency is equal to 0. Plugin can not process.\n", Kernel::ErrorType::OutOfBound);
  92. OV_ERROR_UNLESS_KRF(nChannel > 0, "For condition " << j + 1 << " got no channel in signal stream.\n", Kernel::ErrorType::OutOfBound);
  93. OV_ERROR_UNLESS_KRF(nSample > 0, "For condition " << j + 1 << " got no samples in signal stream.\n", Kernel::ErrorType::OutOfBound);
  94. OV_ERROR_UNLESS_KRF(m_filterDim <= nChannel, "The filter dimension must not be superior than the channel count.\n", Kernel::ErrorType::OutOfBound);
  95. if (!n[0]) // Initialize signal buffer (X[0]) only when receiving input signal header.
  96. {
  97. X[j].resize(nChannel, (dynamicBoxContext.getInputChunkCount(j + 1) - 1) * nSample);
  98. }
  99. else // otherwise, only ERP averaging buffer (X[1]) is reset
  100. {
  101. X[j] = Eigen::MatrixXd::Zero(nChannel, nSample);
  102. }
  103. }
  104. if (decoder.isBufferReceived())
  105. {
  106. Eigen::MatrixXd A = Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
  107. matrix->getBuffer(), nChannel, nSample);
  108. switch (j)
  109. {
  110. case 0: // Session
  111. X[j].block(0, n[j] * A.cols(), A.rows(), A.cols()) = A;
  112. break;
  113. case 1: // ERP
  114. X[j] = X[j] + A; // Computes sumed ERP
  115. // $$$ Assumes continuous session signal starting at date 0
  116. {
  117. size_t ERPSampleIndex = size_t(((dynamicBoxContext.getInputChunkStartTime(j + 1, i) >> 16) * sampling) >> 16);
  118. erpSampleIndexes.push_back(ERPSampleIndex);
  119. }
  120. break;
  121. default:
  122. break;
  123. }
  124. n[j]++;
  125. }
  126. #if 0
  127. if (decoder.isEndReceived())
  128. {
  129. }
  130. #endif
  131. }
  132. OV_ERROR_UNLESS_KRF(n[j] != 0, "Did not have input signal for condition " << j + 1 << "\n", Kernel::ErrorType::BadValue);
  133. switch (j)
  134. {
  135. case 0: // Session
  136. break;
  137. case 1: // ERP
  138. X[j] = X[j] / double(n[j]); // Averages ERP
  139. break;
  140. default:
  141. break;
  142. }
  143. }
  144. // We need equal number of channels
  145. OV_ERROR_UNLESS_KRF(X[0].rows() == X[1].rows(),
  146. "Dimension mismatch, first input had " << size_t(X[0].rows()) << " channels while second input had " << size_t(X[1].rows()) <<
  147. " channels\n",
  148. Kernel::ErrorType::BadValue);
  149. // Grabs usefull values
  150. const size_t sampleCountSession = X[0].cols();
  151. const size_t sampleCountERP = X[1].cols();
  152. // Now we compute matrix D
  153. const Eigen::MatrixXd DI = Eigen::MatrixXd::Identity(sampleCountERP, sampleCountERP);
  154. Eigen::MatrixXd D = Eigen::MatrixXd::Zero(sampleCountERP, sampleCountSession);
  155. for (size_t sampleIndex : erpSampleIndexes) { D.block(0, sampleIndex, sampleCountERP, sampleCountERP) += DI; }
  156. // Computes covariance matrices
  157. C[0] = X[0] * X[0].transpose();
  158. C[1] = /*Y * Y.transpose();*/ X[1] * /* D.transpose() * */ (D * D.transpose()).fullPivLu().inverse() /* * D */ * X[1].transpose();
  159. // Solves generalized eigen decomposition
  160. const Eigen::GeneralizedSelfAdjointEigenSolver<Eigen::MatrixXd> eigenSolver(C[0].selfadjointView<Eigen::Lower>(), C[1].selfadjointView<Eigen::Lower>());
  161. if (eigenSolver.info() != Eigen::Success)
  162. {
  163. const enum Eigen::ComputationInfo error = eigenSolver.info();
  164. const char* errorMessage = "unknown";
  165. switch (error)
  166. {
  167. case Eigen::NumericalIssue: errorMessage = "Numerical issue";
  168. break;
  169. case Eigen::NoConvergence: errorMessage = "No convergence";
  170. break;
  171. // case Eigen::InvalidInput: errorMessage="Invalid input"; break; // FIXME
  172. default: break;
  173. }
  174. OV_ERROR_KRF("Could not solve generalized eigen decomposition, got error[" << CString(errorMessage) << "]\n",
  175. Kernel::ErrorType::BadProcessing);
  176. }
  177. // Create a CMatrix mapper that can spool the filters to a file
  178. CMatrix eigenVectors;
  179. eigenVectors.resize(m_filterDim, nChannel);
  180. Eigen::Map<MatrixXdRowMajor> vectorsMapper(eigenVectors.getBuffer(), m_filterDim, nChannel);
  181. vectorsMapper.block(0, 0, m_filterDim, nChannel) = eigenSolver.eigenvectors().block(0, 0, nChannel, m_filterDim).transpose();
  182. // Saves filters
  183. FILE* file = FS::Files::open(m_filterFilename.toASCIIString(), "wt");
  184. OV_ERROR_UNLESS_KRF(file != nullptr, "Could not open file [" << m_filterFilename << "] for writing.\n", Kernel::ErrorType::BadFileWrite);
  185. if (m_saveAsBoxConfig)
  186. {
  187. fprintf(file, "<OpenViBE-SettingsOverride>\n");
  188. fprintf(file, "\t<SettingValue>");
  189. for (size_t i = 0; i < eigenVectors.getBufferElementCount(); ++i) { fprintf(file, "%e ", eigenVectors.getBuffer()[i]); }
  190. fprintf(file, "</SettingValue>\n");
  191. fprintf(file, "\t<SettingValue>%u</SettingValue>\n", (unsigned int)m_filterDim);
  192. fprintf(file, "\t<SettingValue>%u</SettingValue>\n", (unsigned int)nChannel);
  193. fprintf(file, "\t<SettingValue></SettingValue>\n");
  194. fprintf(file, "</OpenViBE-SettingsOverride>");
  195. }
  196. else
  197. {
  198. OV_ERROR_UNLESS_KRF(Toolkit::Matrix::saveToTextFile(eigenVectors, m_filterFilename),
  199. "Unable to save to [" << m_filterFilename << "]\n", Kernel::ErrorType::BadFileWrite);
  200. }
  201. OV_WARNING_UNLESS_K(fclose(file) == 0, "Could not close file[" << m_filterFilename << "].\n");
  202. this->getLogManager() << Kernel::LogLevel_Info << "Training finished and saved to [" << m_filterFilename << "]!\n";
  203. }
  204. return true;
  205. }
  206. } // namespace SignalProcessing
  207. } // namespace Plugins
  208. } // namespace OpenViBE