You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ParsedExpression.cpp 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. /* -------------------------------------------------------------------------- *
  2. * Lepton *
  3. * -------------------------------------------------------------------------- *
  4. * This is part of the Lepton expression parser originating from *
  5. * Simbios, the NIH National Center for Physics-Based Simulation of *
  6. * Biological Structures at Stanford, funded under the NIH Roadmap for *
  7. * Medical Research, grant U54 GM072970. See https://simtk.org. *
  8. * *
  9. * Portions copyright (c) 2009 Stanford University and the Authors. *
  10. * Authors: Peter Eastman *
  11. * Contributors: *
  12. * *
  13. * Permission is hereby granted, free of charge, to any person obtaining a *
  14. * copy of this software and associated documentation files (the "Software"), *
  15. * to deal in the Software without restriction, including without limitation *
  16. * the rights to use, copy, modify, merge, publish, distribute, sublicense, *
  17. * and/or sell copies of the Software, and to permit persons to whom the *
  18. * Software is furnished to do so, subject to the following conditions: *
  19. * *
  20. * The above copyright notice and this permission notice shall be included in *
  21. * all copies or substantial portions of the Software. *
  22. * *
  23. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
  24. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
  25. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
  26. * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
  27. * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
  28. * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
  29. * USE OR OTHER DEALINGS IN THE SOFTWARE. *
  30. * -------------------------------------------------------------------------- */
  31. #include "ParsedExpression.h"
  32. #include "CompiledExpression.h"
  33. #include "ExpressionProgram.h"
  34. #include "Operation.h"
  35. #include <limits>
  36. #include <vector>
  37. using namespace Lepton;
  38. ParsedExpression::ParsedExpression() : rootNode(ExpressionTreeNode()) {}
  39. ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) {}
  40. const ExpressionTreeNode& ParsedExpression::getRootNode() const { return rootNode; }
  41. double ParsedExpression::evaluate() const { return evaluate(getRootNode(), std::map<std::string, double>()); }
  42. double ParsedExpression::evaluate(const std::map<std::string, double>& variables) const { return evaluate(getRootNode(), variables); }
  43. double ParsedExpression::evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables)
  44. {
  45. size_t numArgs = node.getChildren().size();
  46. std::vector<double> args(std::max(numArgs, size_t(1)));
  47. for (size_t i = 0; i < numArgs; ++i) { args[i] = evaluate(node.getChildren()[i], variables); }
  48. return node.getOperation().evaluate(&args[0], variables);
  49. }
  50. ParsedExpression ParsedExpression::optimize() const
  51. {
  52. ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode());
  53. while (true)
  54. {
  55. ExpressionTreeNode simplified = substituteSimplerExpression(result);
  56. if (simplified == result) { break; }
  57. result = simplified;
  58. }
  59. return ParsedExpression(result);
  60. }
  61. ParsedExpression ParsedExpression::optimize(const std::map<std::string, double>& variables) const
  62. {
  63. ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
  64. result = precalculateConstantSubexpressions(result);
  65. while (true)
  66. {
  67. ExpressionTreeNode simplified = substituteSimplerExpression(result);
  68. if (simplified == result) { break; }
  69. result = simplified;
  70. }
  71. return ParsedExpression(result);
  72. }
  73. ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables)
  74. {
  75. if (node.getOperation().getId() == Operation::VARIABLE)
  76. {
  77. const Operation::Variable& var = dynamic_cast<const Operation::Variable&>(node.getOperation());
  78. auto iter = variables.find(var.getName());
  79. if (iter == variables.end()) { return node; }
  80. return ExpressionTreeNode(new Operation::Constant(iter->second));
  81. }
  82. std::vector<ExpressionTreeNode> children(node.getChildren().size());
  83. for (size_t i = 0; i < children.size(); ++i) { children[i] = preevaluateVariables(node.getChildren()[i], variables); }
  84. return ExpressionTreeNode(node.getOperation().clone(), children);
  85. }
  86. ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node)
  87. {
  88. std::vector<ExpressionTreeNode> children(node.getChildren().size());
  89. for (size_t i = 0; i < children.size(); ++i) { children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); }
  90. ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
  91. if (node.getOperation().getId() == Operation::VARIABLE) { return result; }
  92. for (size_t i = 0; i < children.size(); ++i) { if (children[i].getOperation().getId() != Operation::CONSTANT) { return result; } }
  93. return ExpressionTreeNode(new Operation::Constant(evaluate(result, std::map<std::string, double>())));
  94. }
  95. ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node)
  96. {
  97. std::vector<ExpressionTreeNode> childs(node.getChildren().size());
  98. for (size_t i = 0; i < childs.size(); ++i) { childs[i] = substituteSimplerExpression(node.getChildren()[i]); }
  99. Operation::Id op1 = childs[0].getOperation().getId();
  100. Operation::Id op2 = childs[1].getOperation().getId();
  101. switch (node.getOperation().getId())
  102. {
  103. case Operation::ADD:
  104. {
  105. const double first = getConstantValue(childs[0]);
  106. const double second = getConstantValue(childs[1]);
  107. if (first == 0.0) { return childs[1]; } // Add 0
  108. if (second == 0.0) { return childs[0]; } // Add 0
  109. if (first == first) { return ExpressionTreeNode(new Operation::AddConstant(first), childs[1]); } // Add a constant
  110. if (second == second) { return ExpressionTreeNode(new Operation::AddConstant(second), childs[0]); } // Add a constant
  111. if (op2 == Operation::NEGATE) { return ExpressionTreeNode(new Operation::Subtract(), childs[0], childs[1].getChildren()[0]); } // a+(-b) = a-b
  112. if (op1 == Operation::NEGATE) { return ExpressionTreeNode(new Operation::Subtract(), childs[1], childs[0].getChildren()[0]); } // (-a)+b = b-a
  113. break;
  114. }
  115. case Operation::SUBTRACT:
  116. {
  117. if (childs[0] == childs[1]) { return ExpressionTreeNode(new Operation::Constant(0.0)); } // Subtracting anything from itself is 0
  118. const double first = getConstantValue(childs[0]);
  119. if (first == 0.0) { return ExpressionTreeNode(new Operation::Negate(), childs[1]); } // Subtract from 0
  120. const double second = getConstantValue(childs[1]);
  121. if (second == 0.0) { return childs[0]; } // Subtract 0
  122. if (second == second) { return ExpressionTreeNode(new Operation::AddConstant(-second), childs[0]); } // Subtract a constant
  123. if (op2 == Operation::NEGATE) { return ExpressionTreeNode(new Operation::Add(), childs[0], childs[1].getChildren()[0]); } // a-(-b) = a+b
  124. break;
  125. }
  126. case Operation::MULTIPLY:
  127. {
  128. double first = getConstantValue(childs[0]);
  129. double second = getConstantValue(childs[1]);
  130. if (first == 0.0 || second == 0.0) { return ExpressionTreeNode(new Operation::Constant(0.0)); } // Multiply by 0
  131. if (first == 1.0) { return childs[1]; } // Multiply by 1
  132. if (second == 1.0) { return childs[0]; } // Multiply by 1
  133. if (op1 == Operation::CONSTANT)
  134. { // Multiply by a constant
  135. if (op2 == Operation::MULTIPLY_CONSTANT)
  136. { // Combine two multiplies into a single one
  137. return ExpressionTreeNode(
  138. new Operation::MultiplyConstant(first * dynamic_cast<const Operation::MultiplyConstant*>(&childs[1].getOperation())->getValue()),
  139. childs[1].getChildren()[0]);
  140. }
  141. return ExpressionTreeNode(new Operation::MultiplyConstant(first), childs[1]);
  142. }
  143. if (op2 == Operation::CONSTANT)
  144. { // Multiply by a constant
  145. if (op1 == Operation::MULTIPLY_CONSTANT)
  146. { // Combine two multiplies into a single one
  147. return ExpressionTreeNode(
  148. new Operation::MultiplyConstant(second * dynamic_cast<const Operation::MultiplyConstant*>(&childs[0].getOperation())->getValue()),
  149. childs[0].getChildren()[0]);
  150. }
  151. return ExpressionTreeNode(new Operation::MultiplyConstant(second), childs[0]);
  152. }
  153. if (op1 == Operation::NEGATE && op2 == Operation::NEGATE)
  154. { // The two negations cancel
  155. return ExpressionTreeNode(new Operation::Multiply(), childs[0].getChildren()[0], childs[1].getChildren()[0]);
  156. }
  157. if (op1 == Operation::NEGATE && op2 == Operation::MULTIPLY_CONSTANT)
  158. { // Negate the constant
  159. return ExpressionTreeNode(new Operation::Multiply(), childs[0].getChildren()[0],
  160. ExpressionTreeNode(
  161. new Operation::MultiplyConstant(
  162. -dynamic_cast<const Operation::MultiplyConstant*>(&childs[1].getOperation())->getValue()),
  163. childs[1].getChildren()[0]));
  164. }
  165. if (op2 == Operation::NEGATE && op1 == Operation::MULTIPLY_CONSTANT)
  166. { // Negate the constant
  167. return ExpressionTreeNode(new Operation::Multiply(),
  168. ExpressionTreeNode(
  169. new Operation::MultiplyConstant(
  170. -dynamic_cast<const Operation::MultiplyConstant*>(&childs[0].getOperation())->getValue()),
  171. childs[0].getChildren()[0]), childs[1].getChildren()[0]);
  172. }
  173. if (op1 == Operation::NEGATE)
  174. { // Pull the negation out so it can possibly be optimized further
  175. return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), childs[0].getChildren()[0], childs[1]));
  176. }
  177. if (op2 == Operation::NEGATE)
  178. { // Pull the negation out so it can possibly be optimized further
  179. return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), childs[0], childs[1].getChildren()[0]));
  180. }
  181. if (op2 == Operation::RECIPROCAL) { return ExpressionTreeNode(new Operation::Divide(), childs[0], childs[1].getChildren()[0]); } // a*(1/b) = a/b
  182. if (op1 == Operation::RECIPROCAL) { return ExpressionTreeNode(new Operation::Divide(), childs[1], childs[0].getChildren()[0]); } // (1/a)*b = b/a
  183. if (childs[0] == childs[1]) { return ExpressionTreeNode(new Operation::Square(), childs[0]); } // x*x = square(x)
  184. if (op1 == Operation::SQUARE && childs[0].getChildren()[0] == childs[1]) { return ExpressionTreeNode(new Operation::Cube(), childs[1]); } // x^3
  185. if (op2 == Operation::SQUARE && childs[1].getChildren()[0] == childs[0]) { return ExpressionTreeNode(new Operation::Cube(), childs[0]); } // x^3
  186. break;
  187. }
  188. case Operation::DIVIDE:
  189. {
  190. if (childs[0] == childs[1]) { return ExpressionTreeNode(new Operation::Constant(1.0)); } // Dividing anything from itself is 0
  191. const double numerator = getConstantValue(childs[0]);
  192. if (numerator == 0.0) { return ExpressionTreeNode(new Operation::Constant(0.0)); } // 0 divided by something
  193. if (numerator == 1.0) { return ExpressionTreeNode(new Operation::Reciprocal(), childs[1]); } // 1 divided by something
  194. const double denominator = getConstantValue(childs[1]);
  195. if (denominator == 1.0) { return childs[0]; } // Divide by 1
  196. if (op2 == Operation::CONSTANT)
  197. {
  198. if (op1 == Operation::MULTIPLY_CONSTANT)
  199. { // Combine a multiply and a divide into one multiply
  200. return ExpressionTreeNode(new Operation::MultiplyConstant
  201. (dynamic_cast<const Operation::MultiplyConstant*>(&childs[0].getOperation())->getValue() / denominator),
  202. childs[0].getChildren()[0]);
  203. }
  204. return ExpressionTreeNode(new Operation::MultiplyConstant(1.0 / denominator), childs[0]); // Replace a divide with a multiply
  205. }
  206. if (op1 == Operation::NEGATE && op2 == Operation::NEGATE)
  207. { // The two negations cancel
  208. return ExpressionTreeNode(new Operation::Divide(), childs[0].getChildren()[0], childs[1].getChildren()[0]);
  209. }
  210. if (op2 == Operation::NEGATE && op1 == Operation::MULTIPLY_CONSTANT)
  211. { // Negate the constant
  212. return ExpressionTreeNode(new Operation::Divide(), ExpressionTreeNode(
  213. new Operation::MultiplyConstant(
  214. -dynamic_cast<const Operation::MultiplyConstant*>(&childs[0].getOperation())->getValue()),
  215. childs[0].getChildren()[0]), childs[1].getChildren()[0]);
  216. }
  217. if (op1 == Operation::NEGATE)
  218. { // Pull the negation out so it can possibly be optimized further
  219. return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Divide(), childs[0].getChildren()[0], childs[1]));
  220. }
  221. if (op2 == Operation::NEGATE)
  222. { // Pull the negation out so it can possibly be optimized further
  223. return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Divide(), childs[0], childs[1].getChildren()[0]));
  224. }
  225. if (childs[1].getOperation().getId() == Operation::RECIPROCAL)
  226. { // a/(1/b) = a*b
  227. return ExpressionTreeNode(new Operation::Multiply(), childs[0], childs[1].getChildren()[0]);
  228. }
  229. break;
  230. }
  231. case Operation::POWER:
  232. {
  233. double base = getConstantValue(childs[0]);
  234. if (base == 0.0) { return ExpressionTreeNode(new Operation::Constant(0.0)); } // 0 to any power is 0
  235. if (base == 1.0) { return ExpressionTreeNode(new Operation::Constant(1.0)); } // 1 to any power is 1
  236. double exponent = getConstantValue(childs[1]);
  237. if (exponent == 0.0) { return ExpressionTreeNode(new Operation::Constant(1.0)); } // x^0 = 1
  238. if (exponent == 1.0) { return childs[0]; } // x^1 = x
  239. if (exponent == -1.0) { return ExpressionTreeNode(new Operation::Reciprocal(), childs[0]); } // x^-1 = recip(x)
  240. if (exponent == 2.0) { return ExpressionTreeNode(new Operation::Square(), childs[0]); } // x^2 = square(x)
  241. if (exponent == 3.0) { return ExpressionTreeNode(new Operation::Cube(), childs[0]); } // x^3 = cube(x)
  242. if (exponent == 0.5) { return ExpressionTreeNode(new Operation::Sqrt(), childs[0]); } // x^0.5 = sqrt(x)
  243. if (exponent == exponent) { return ExpressionTreeNode(new Operation::PowerConstant(exponent), childs[0]); } // Constant power
  244. break;
  245. }
  246. case Operation::NEGATE:
  247. {
  248. if (op1 == Operation::MULTIPLY_CONSTANT)
  249. { // Combine a multiply and a negate into a single multiply
  250. return ExpressionTreeNode(
  251. new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&childs[0].getOperation())->getValue()),
  252. childs[0].getChildren()[0]);
  253. }
  254. if (op1 == Operation::CONSTANT) { return ExpressionTreeNode(new Operation::Constant(-getConstantValue(childs[0]))); } // Negate a constant
  255. if (op1 == Operation::NEGATE) { return childs[0].getChildren()[0]; } // The two negations cancel
  256. break;
  257. }
  258. case Operation::MULTIPLY_CONSTANT:
  259. {
  260. if (op1 == Operation::MULTIPLY_CONSTANT)
  261. { // Combine two multiplies into a single one
  262. return ExpressionTreeNode(
  263. new Operation::MultiplyConstant(
  264. dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue() * dynamic_cast<const Operation::MultiplyConstant*>(&
  265. childs[0].getOperation())->getValue()), childs[0].getChildren()[0]);
  266. }
  267. if (op1 == Operation::CONSTANT)
  268. { // Multiply two constants
  269. return ExpressionTreeNode(
  270. new Operation::Constant(
  271. dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue() * getConstantValue(childs[0])));
  272. }
  273. if (op1 == Operation::NEGATE)
  274. { // Combine a multiply and a negate into a single multiply
  275. return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()),
  276. childs[0].getChildren()[0]);
  277. }
  278. break;
  279. }
  280. default:
  281. {
  282. // If operation ID is not one of the above,
  283. // we don't substitute a simpler expression.
  284. break;
  285. }
  286. }
  287. return ExpressionTreeNode(node.getOperation().clone(), childs);
  288. }
  289. ParsedExpression ParsedExpression::differentiate(const std::string& variable) const { return differentiate(getRootNode(), variable); }
  290. ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const std::string& variable)
  291. {
  292. std::vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
  293. for (size_t i = 0; i < childDerivs.size(); ++i) { childDerivs[i] = differentiate(node.getChildren()[i], variable); }
  294. return node.getOperation().differentiate(node.getChildren(), childDerivs, variable);
  295. }
  296. double ParsedExpression::getConstantValue(const ExpressionTreeNode& node)
  297. {
  298. if (node.getOperation().getId() == Operation::CONSTANT) { return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue(); }
  299. return std::numeric_limits<double>::quiet_NaN();
  300. }
  301. ExpressionProgram ParsedExpression::createProgram() const { return ExpressionProgram(*this); }
  302. CompiledExpression ParsedExpression::createCompiledExpression() const { return CompiledExpression(*this); }
  303. ParsedExpression ParsedExpression::renameVariables(const std::map<std::string, std::string>& replacements) const
  304. {
  305. return ParsedExpression(renameNodeVariables(getRootNode(), replacements));
  306. }
  307. ExpressionTreeNode ParsedExpression::renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements)
  308. {
  309. if (node.getOperation().getId() == Operation::VARIABLE)
  310. {
  311. auto replace = replacements.find(node.getOperation().getName());
  312. if (replace != replacements.end()) { return ExpressionTreeNode(new Operation::Variable(replace->second)); }
  313. }
  314. std::vector<ExpressionTreeNode> children;
  315. for (size_t i = 0; i < node.getChildren().size(); ++i) { children.push_back(renameNodeVariables(node.getChildren()[i], replacements)); }
  316. return ExpressionTreeNode(node.getOperation().clone(), children);
  317. }
  318. std::ostream& Lepton::operator<<(std::ostream& out, const ExpressionTreeNode& node)
  319. {
  320. if (node.getOperation().isInfixOperator() && node.getChildren().size() == 2)
  321. {
  322. out << "(" << node.getChildren()[0] << ")" << node.getOperation().getName() << "(" << node.getChildren()[1] << ")";
  323. }
  324. else if (node.getOperation().isInfixOperator() && node.getChildren().size() == 1)
  325. {
  326. out << "(" << node.getChildren()[0] << ")" << node.getOperation().getName();
  327. }
  328. else
  329. {
  330. out << node.getOperation().getName();
  331. if (!node.getChildren().empty())
  332. {
  333. out << "(";
  334. for (size_t i = 0; i < node.getChildren().size(); ++i)
  335. {
  336. if (i > 0) { out << ", "; }
  337. out << node.getChildren()[i];
  338. }
  339. out << ")";
  340. }
  341. }
  342. return out;
  343. }
  344. std::ostream& Lepton::operator<<(std::ostream& out, const ParsedExpression& exp)
  345. {
  346. out << exp.getRootNode();
  347. return out;
  348. }