You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Classification.cpp 3.5KB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #include "geometry/Classification.hpp"
  2. #include "geometry/Covariance.hpp"
  3. #include "geometry/Basics.hpp"
  4. namespace Geometry {
  5. ///-------------------------------------------------------------------------------------------------
  6. bool LSQR(const std::vector<std::vector<Eigen::RowVectorXd>>& dataset, Eigen::MatrixXd& weight)
  7. {
  8. // Precomputation
  9. if (dataset.empty()) { return false; }
  10. const size_t nbClass = dataset.size(), nbFeatures = dataset[0][0].size();
  11. std::vector<size_t> nbSample(nbClass);
  12. size_t totalSample = 0;
  13. for (size_t k = 0; k < nbClass; ++k)
  14. {
  15. if (dataset[k].empty()) { return false; }
  16. nbSample[k] = dataset[k].size();
  17. totalSample += nbSample[k];
  18. }
  19. // Compute Class Euclidian mean
  20. Eigen::MatrixXd mean = Eigen::MatrixXd::Zero(nbClass, nbFeatures);
  21. for (size_t k = 0; k < nbClass; ++k)
  22. {
  23. for (size_t i = 0; i < nbSample[k]; ++i) { mean.row(k) += dataset[k][i]; }
  24. mean.row(k) /= double(nbSample[k]);
  25. }
  26. // Compute Class Covariance
  27. Eigen::MatrixXd cov = Eigen::MatrixXd::Zero(nbFeatures, nbFeatures);
  28. for (size_t k = 0; k < nbClass; ++k)
  29. {
  30. //Fit Data to existing covariance matrix method
  31. Eigen::MatrixXd classData(nbFeatures, nbSample[k]);
  32. for (size_t i = 0; i < nbSample[k]; ++i) { classData.col(i) = dataset[k][i]; }
  33. // Standardize Features
  34. Eigen::RowVectorXd scale;
  35. MatrixStandardScaler(classData, scale);
  36. //Compute Covariance of this class
  37. Eigen::MatrixXd classCov;
  38. if (!CovarianceMatrix(classData, classCov, EEstimator::LWF)) { return false; }
  39. // Rescale
  40. for (size_t i = 0; i < nbFeatures; ++i) { for (size_t j = 0; j < nbFeatures; ++j) { classCov(i, j) *= scale[i] * scale[j]; } }
  41. //Add to cov with good weight
  42. cov += (double(nbSample[k]) / double(totalSample)) * classCov;
  43. }
  44. // linear least squares systems solver
  45. // Chosen solver with the performance table of this page : https://eigen.tuxfamily.org/dox/group__TutorialLinearAlgebra.html
  46. weight = cov.colPivHouseholderQr().solve(mean.transpose()).transpose();
  47. //weight = cov.completeOrthogonalDecomposition().solve(mean.transpose()).transpose();
  48. //weight = cov.bdcSvd(ComputeThinU | ComputeThinV).solve(mean.transpose()).transpose();
  49. // Treat binary case as a special case
  50. if (nbClass == 2)
  51. {
  52. const Eigen::MatrixXd tmp = weight.row(1) - weight.row(0); // Need to use a tmp variable otherwise sometimes error
  53. weight = tmp;
  54. }
  55. return true;
  56. }
  57. ///-------------------------------------------------------------------------------------------------
  58. ///-------------------------------------------------------------------------------------------------
  59. bool FgDACompute(const std::vector<std::vector<Eigen::RowVectorXd>>& dataset, Eigen::MatrixXd& weight)
  60. {
  61. // Compute LSQR Weight
  62. Eigen::MatrixXd w;
  63. if (!LSQR(dataset, w)) { return false; }
  64. const size_t nbClass = w.rows();
  65. // Transform to FgDA Weight
  66. const Eigen::MatrixXd wT = w.transpose();
  67. weight = (wT * (w * wT).colPivHouseholderQr().solve(Eigen::MatrixXd::Identity(nbClass, nbClass))) * w;
  68. return true;
  69. }
  70. ///-------------------------------------------------------------------------------------------------
  71. ///-------------------------------------------------------------------------------------------------
  72. bool FgDAApply(const Eigen::RowVectorXd& in, Eigen::RowVectorXd& out, const Eigen::MatrixXd& weight)
  73. {
  74. if (in.cols() != weight.rows()) { return false; }
  75. out = in * weight;
  76. return true;
  77. }
  78. ///-------------------------------------------------------------------------------------------------
  79. } // namespace Geometry