You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

svm.cpp 69KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875
  1. #include <math.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <ctype.h>
  5. #include <float.h>
  6. #include <string.h>
  7. #include <stdarg.h>
  8. #include <limits.h>
  9. #include <locale.h>
  10. #include "svm.h"
  // Library version number as a global variable, initialised from the LIBSVM_VERSION macro
  // (presumably defined in svm.h -- confirm).
  11. int libsvm_version = LIBSVM_VERSION;
  // Single-precision element type used for cached kernel/Q-matrix values.
  12. typedef float Qfloat;
  // Small signed integer type; the solver below stores the +1/-1 labels in it.
  13. typedef signed char schar;
  #ifndef min
  // Returns x when x compares strictly less than y, otherwise y.
  // Only defined when no `min` macro is already in effect.
  template <class T>
  static T min(const T x, const T y)
  {
      if (x < y) { return x; }
      return y;
  }
  #endif
  #ifndef max
  // Returns x when x compares strictly greater than y, otherwise y.
  // Only defined when no `max` macro is already in effect.
  template <class T>
  static T max(const T x, const T y)
  {
      if (x > y) { return x; }
      return y;
  }
  #endif
  // Exchanges the values of x and y through a temporary copy.
  template <class T>
  static void swap(T& x, T& y)
  {
      const T saved_x = x;
      x = y;
      y = saved_x;
  }
  // Allocates a new array of n elements of type T with new[], fills it with a raw
  // byte copy of the first sizeof(T) * n bytes at src, and stores it in dst.
  // The caller owns the result and must release it with delete[].
  template <class S, class T>
  static void clone(T*& dst, S* src, const int n)
  {
      T* const copy = new T[n];
      const size_t byte_count = sizeof(T) * n;
      memcpy((void*)copy, (void*)src, byte_count);
      dst = copy;
  }
  // Raises base to the integer power `times` by repeated squaring.
  // A zero or negative exponent leaves the loop untouched and yields 1.0.
  static double powi(const double base, const int times)
  {
      double result = 1.0;
      double square = base;
      int remaining = times;
      while (remaining > 0)
      {
          // Multiply in the current square whenever the lowest remaining bit is set.
          if (remaining % 2 == 1) { result *= square; }
          square = square * square;
          remaining /= 2;
      }
      return result;
  }
  // Positive infinity; used as the starting value of running minima and maxima.
  #define INF HUGE_VAL
  // Small positive value substituted for a non-positive quadratic coefficient.
  #define TAU 1e-12
  // Allocates an uninitialised array of n objects of `type` with malloc().
  // The result may be NULL and must be released with free().
  // The whole expansion is parenthesised so that postfix operators applied to the
  // macro (indexing, member access) bind to the cast result, not to malloc().
  #define Malloc(type,n) ((type *)malloc((n) * sizeof(type)))
  // Default message sink: writes s to standard output and flushes immediately.
  static void print_string_stdout(const char* s)
  {
      FILE* const out = stdout;
      fputs(s, out);
      fflush(out);
  }
  // Currently installed message sink; starts out pointing at print_string_stdout.
  static void (*svm_print_string)(const char*) = print_string_stdout;
  #if 0
  // Formats a printf-style message into a fixed buffer and forwards it to the
  // installed print callback. Output longer than the buffer is truncated instead
  // of overflowing it; a formatting error yields an empty message.
  static void info(const char* fmt, ...)
  {
      char buf[BUFSIZ];
      va_list ap;
      va_start(ap, fmt);
      if (vsnprintf(buf, sizeof(buf), fmt, ap) < 0) { buf[0] = '\0'; }
      va_end(ap);
      (*svm_print_string)(buf);
  }
  #else
  // Logging is compiled out: accepts any printf-style arguments and does nothing.
  static void info(const char* /*fmt*/, ...) {}
  #endif
  67. //
  68. // Kernel Cache
  69. //
  70. // l is the number of total data items
  71. // size is the cache size limit in bytes
  72. //
  // Cache of kernel-matrix columns with one slot per data item. Slots that hold
  // data are linked into a circular list; when space is needed the entry at the
  // front of that list is evicted first.
  73. class Cache
  74. {
  75. public:
  // l_ is the number of data items (one slot each); size_ is the budget in bytes.
  76. Cache(const int l_, const long int size_);
  // Frees the data of every entry still linked in the list, then the slot table.
  77. ~Cache();
  78. // request data [0,len)
  79. // return some position p where [p,len) need to be filled
  80. // (p >= len if nothing needs to be filled)
  81. int get_data(const int index, Qfloat** data, int len);
  // Exchanges slots i and j and swaps elements i and j inside every cached column
  // long enough to hold both; columns that cover only one of the two are dropped.
  82. void swap_index(int i, int j);
  83. private:
  // Number of data items, i.e. number of slots in `head`.
  84. int l;
  // Remaining budget. The constructor converts it from bytes to a count of Qfloat elements.
  85. long int size;
  86. struct head_t
  87. {
  88. head_t *prev, *next; // a circular list
  89. Qfloat* data;
  90. int len; // data[0,len) is cached in this entry
  91. };
  // Table of l slots, allocated zero-initialised by the constructor.
  92. head_t* head;
  // Sentinel node of the circular list; only its prev/next links are used.
  93. head_t lru_head;
  // Unlinks h from the list.
  94. void lru_delete(head_t* h);
  // Links h in just before the sentinel, i.e. at the last position.
  95. void lru_insert(head_t* h);
  96. };
  97. Cache::Cache(const int l_, const long int size_) : l(l_), size(size_)
  98. {
  99. head = (head_t*)calloc(l, sizeof(head_t)); // initialized to 0
  100. size /= sizeof(Qfloat);
  101. size -= l * sizeof(head_t) / sizeof(Qfloat);
  102. size = max(size, 2 * long(l)); // cache must be large enough for two columns
  103. lru_head.next = lru_head.prev = &lru_head;
  104. }
  105. Cache::~Cache()
  106. {
  107. for (head_t* h = lru_head.next; h != &lru_head; h = h->next) { free(h->data); }
  108. free(head);
  109. }
  110. void Cache::lru_delete(head_t* h)
  111. {
  112. // delete from current location
  113. h->prev->next = h->next;
  114. h->next->prev = h->prev;
  115. }
  116. void Cache::lru_insert(head_t* h)
  117. {
  118. // insert to last position
  119. h->next = &lru_head;
  120. h->prev = lru_head.prev;
  121. h->prev->next = h;
  122. h->next->prev = h;
  123. }
  124. int Cache::get_data(const int index, Qfloat** data, int len)
  125. {
  126. head_t* h = &head[index];
  127. if (h->len) { lru_delete(h); }
  128. const int more = len - h->len;
  129. if (more > 0)
  130. {
  131. // free old space
  132. while (size < more)
  133. {
  134. head_t* old = lru_head.next;
  135. lru_delete(old);
  136. free(old->data);
  137. size += old->len;
  138. old->data = nullptr;
  139. old->len = 0;
  140. }
  141. // allocate new space
  142. h->data = (Qfloat*)realloc(h->data, sizeof(Qfloat) * len);
  143. size -= more;
  144. swap(h->len, len);
  145. }
  146. lru_insert(h);
  147. *data = h->data;
  148. return len;
  149. }
  150. void Cache::swap_index(int i, int j)
  151. {
  152. if (i == j) { return; }
  153. if (head[i].len) { lru_delete(&head[i]); }
  154. if (head[j].len) { lru_delete(&head[j]); }
  155. swap(head[i].data, head[j].data);
  156. swap(head[i].len, head[j].len);
  157. if (head[i].len) { lru_insert(&head[i]); }
  158. if (head[j].len) { lru_insert(&head[j]); }
  159. if (i > j) { swap(i, j); }
  160. for (head_t* h = lru_head.next; h != &lru_head; h = h->next)
  161. {
  162. if (h->len > i)
  163. {
  164. if (h->len > j) { swap(h->data[i], h->data[j]); }
  165. else
  166. {
  167. // give up
  168. lru_delete(h);
  169. free(h->data);
  170. size += h->len;
  171. h->data = nullptr;
  172. h->len = 0;
  173. }
  174. }
  175. }
  176. }
  177. //
  178. // Kernel evaluation
  179. //
  180. // the static method k_function is for doing single kernel evaluation
  181. // the constructor of Kernel prepares to calculate the l*l kernel matrix
  182. // the member function get_Q is for getting one column from the Q Matrix
  183. //
  // Abstract interface through which the solver reads the Q matrix.
  184. class QMatrix
  185. {
  186. public:
  // Returns a pointer to the first `len` entries of the given column.
  187. virtual Qfloat* get_Q(int column, int len) const = 0;
  // Returns an array that the solver indexes as QD[i] alongside Q_i[j] when forming
  // quadratic coefficients (presumably the diagonal of Q -- confirm in implementations).
  188. virtual double* get_QD() const = 0;
  // Exchanges items i and j; declared const although implementations reorder internal data.
  189. virtual void swap_index(int i, int j) const = 0;
  190. virtual ~QMatrix() {}
  191. };
  // Base class for Q matrices built from a kernel function. It keeps its own array of
  // pointers to the training vectors and selects the kernel routine in the constructor.
  // get_Q and get_QD stay pure virtual here.
  192. class Kernel : public QMatrix
  193. {
  194. public:
  // l is the number of vectors in x_; the kernel parameters are copied from param.
  195. Kernel(const int l, svm_node* const* x_, const svm_parameter& param);
  196. ~Kernel() override;
  // Single kernel evaluation between two vectors, independent of any Kernel instance.
  197. static double k_function(const svm_node* x, const svm_node* y, const svm_parameter& param);
  198. Qfloat* get_Q(int column, int len) const override = 0;
  199. double* get_QD() const override = 0;
  // Swaps the i-th and j-th vector pointers and, when present, their cached squared norms.
  200. void swap_index(const int i, const int j) const override
  201. // not so const: the pointed-to arrays are reordered even though the method is const
  202. {
  203. swap(x[i], x[j]);
  204. if (x_square) { swap(x_square[i], x_square[j]); }
  205. }
  206. protected:
  // Kernel routine chosen by the constructor from kernel_type.
  207. double (Kernel::* kernel_function)(int i, int j) const;
  208. private:
  // Private copy of the array of vector pointers (the vectors themselves are not copied).
  209. const svm_node** x;
  // dot(x[i], x[i]) for every i when the kernel is RBF, otherwise null.
  210. double* x_square;
  211. // svm_parameter
  212. const int kernel_type;
  213. const int degree;
  214. const double gamma;
  215. const double coef0;
  // Sparse dot product of two vectors whose entries end at index == -1.
  216. static double dot(const svm_node* px, const svm_node* py);
  217. double kernel_linear(const int i, const int j) const { return dot(x[i], x[j]); }
  218. double kernel_poly(const int i, const int j) const { return powi(gamma * dot(x[i], x[j]) + coef0, degree); }
  // Uses the cached squared norms: |xi - xj|^2 = xi.xi + xj.xj - 2 xi.xj.
  219. double kernel_rbf(const int i, const int j) const { return exp(-gamma * (x_square[i] + x_square[j] - 2 * dot(x[i], x[j]))); }
  220. double kernel_sigmoid(const int i, const int j) const { return tanh(gamma * dot(x[i], x[j]) + coef0); }
  // Looks the value up in row i, at the position given by the value of row j's first entry
  // (presumably that entry carries the sample's serial number -- confirm against the data format).
  221. double kernel_precomputed(const int i, const int j) const { return x[i][int(x[j][0].value)].value; }
  222. };
  223. Kernel::Kernel(const int l, svm_node* const* x_, const svm_parameter& param)
  224. : kernel_type(param.kernel_type), degree(param.degree), gamma(param.gamma), coef0(param.coef0)
  225. {
  226. switch (kernel_type)
  227. {
  228. case LINEAR:
  229. kernel_function = &Kernel::kernel_linear;
  230. break;
  231. case POLY:
  232. kernel_function = &Kernel::kernel_poly;
  233. break;
  234. case RBF:
  235. kernel_function = &Kernel::kernel_rbf;
  236. break;
  237. case SIGMOID:
  238. kernel_function = &Kernel::kernel_sigmoid;
  239. break;
  240. case PRECOMPUTED:
  241. kernel_function = &Kernel::kernel_precomputed;
  242. break;
  243. }
  244. clone(x, x_, l);
  245. if (kernel_type == RBF)
  246. {
  247. x_square = new double[l];
  248. for (int i = 0; i < l; i++) { x_square[i] = dot(x[i], x[i]); }
  249. }
  250. else { x_square = nullptr; }
  251. }
  252. Kernel::~Kernel()
  253. {
  254. delete[] x;
  255. delete[] x_square;
  256. }
  257. double Kernel::dot(const svm_node* px, const svm_node* py)
  258. {
  259. double sum = 0;
  260. while (px->index != -1 && py->index != -1)
  261. {
  262. if (px->index == py->index)
  263. {
  264. sum += px->value * py->value;
  265. ++px;
  266. ++py;
  267. }
  268. else
  269. {
  270. if (px->index > py->index) { ++py; }
  271. else { ++px; }
  272. }
  273. }
  274. return sum;
  275. }
  276. double Kernel::k_function(const svm_node* x, const svm_node* y, const svm_parameter& param)
  277. {
  278. switch (param.kernel_type)
  279. {
  280. case LINEAR:
  281. return dot(x, y);
  282. case POLY:
  283. return powi(param.gamma * dot(x, y) + param.coef0, param.degree);
  284. case RBF:
  285. {
  286. double sum = 0;
  287. while (x->index != -1 && y->index != -1)
  288. {
  289. if (x->index == y->index)
  290. {
  291. const double d = x->value - y->value;
  292. sum += d * d;
  293. ++x;
  294. ++y;
  295. }
  296. else
  297. {
  298. if (x->index > y->index)
  299. {
  300. sum += y->value * y->value;
  301. ++y;
  302. }
  303. else
  304. {
  305. sum += x->value * x->value;
  306. ++x;
  307. }
  308. }
  309. }
  310. while (x->index != -1)
  311. {
  312. sum += x->value * x->value;
  313. ++x;
  314. }
  315. while (y->index != -1)
  316. {
  317. sum += y->value * y->value;
  318. ++y;
  319. }
  320. return exp(-param.gamma * sum);
  321. }
  322. case SIGMOID:
  323. return tanh(param.gamma * dot(x, y) + param.coef0);
  324. case PRECOMPUTED: //x: test (validation), y: SV
  325. return x[int(y->value)].value;
  326. default:
  327. return 0; // Unreachable
  328. }
  329. }
  330. // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
  331. // Solves:
  332. //
  333. // min 0.5(\alpha^T Q \alpha) + p^T \alpha
  334. //
  335. // y^T \alpha = \delta
  336. // y_i = +1 or -1
  337. // 0 <= alpha_i <= Cp for y_i = 1
  338. // 0 <= alpha_i <= Cn for y_i = -1
  339. //
  340. // Given:
  341. //
  342. // Q, p, y, Cp, Cn, and an initial feasible point \alpha
  343. // l is the size of vectors and matrices
  344. // eps is the stopping tolerance
  345. //
  346. // solution will be put in \alpha, objective value will be put in obj
  347. //
  348. class Solver
  349. {
  350. public:
  351. Solver() {}
  352. virtual ~Solver() {}
  353. struct SolutionInfo
  354. {
  355. double obj;
  356. double rho;
  357. double upper_bound_p;
  358. double upper_bound_n;
  359. double r; // for Solver_NU
  360. };
  361. void Solve(int l, const QMatrix& Q, const double* p_, const schar* y_, double* alpha_, double Cp, double Cn, double eps, SolutionInfo* si, int shrinking);
  362. protected:
  363. enum { LOWER_BOUND, UPPER_BOUND, FREE };
  364. int active_size = 0;
  365. schar* y = nullptr;
  366. double* G = nullptr; // gradient of objective function
  367. char* alpha_status = nullptr; // LOWER_BOUND, UPPER_BOUND, FREE
  368. double* alpha = nullptr;
  369. const QMatrix* Q;
  370. const double* QD;
  371. double eps = 0.0;
  372. double Cp = 0.0, Cn = 0.0;
  373. double* p = nullptr;
  374. int* active_set = nullptr;
  375. double* G_bar = nullptr; // gradient, if we treat free variables as 0
  376. int l = 0.0;
  377. bool unshrink = true; // XXX
  378. double get_C(const int i) { return (y[i] > 0) ? Cp : Cn; }
  379. void update_alpha_status(const int i)
  380. {
  381. if (alpha[i] >= get_C(i)) { alpha_status[i] = UPPER_BOUND; }
  382. else if (alpha[i] <= 0) { alpha_status[i] = LOWER_BOUND; }
  383. else { alpha_status[i] = FREE; }
  384. }
  385. bool is_upper_bound(const int i) { return alpha_status[i] == UPPER_BOUND; }
  386. bool is_lower_bound(const int i) { return alpha_status[i] == LOWER_BOUND; }
  387. bool is_free(const int i) { return alpha_status[i] == FREE; }
  388. void swap_index(const int i, const int j);
  389. void reconstruct_gradient();
  390. virtual int select_working_set(int& out_i, int& out_j);
  391. virtual double calculate_rho();
  392. virtual void do_shrinking();
  393. private:
  394. bool be_shrunk(int i, double Gmax1, double Gmax2);
  395. };
  396. void Solver::swap_index(const int i, const int j)
  397. {
  398. Q->swap_index(i, j);
  399. swap(y[i], y[j]);
  400. swap(G[i], G[j]);
  401. swap(alpha_status[i], alpha_status[j]);
  402. swap(alpha[i], alpha[j]);
  403. swap(p[i], p[j]);
  404. swap(active_set[i], active_set[j]);
  405. swap(G_bar[i], G_bar[j]);
  406. }
  407. void Solver::reconstruct_gradient()
  408. {
  409. // reconstruct inactive elements of G from G_bar and free variables
  410. if (active_size == l) { return; }
  411. int i, j;
  412. int nr_free = 0;
  413. for (j = active_size; j < l; j++) { G[j] = G_bar[j] + p[j]; }
  414. for (j = 0; j < active_size; j++) { if (is_free(j)) { nr_free++; } }
  415. if (2 * nr_free < active_size) { info("\nWARNING: using -h 0 may be faster\n"); }
  416. if (nr_free * l > 2 * active_size * (l - active_size))
  417. {
  418. for (i = active_size; i < l; i++)
  419. {
  420. const Qfloat* Q_i = Q->get_Q(i, active_size);
  421. for (j = 0; j < active_size; j++) { if (is_free(j)) { G[i] += alpha[j] * Q_i[j]; } }
  422. }
  423. }
  424. else
  425. {
  426. for (i = 0; i < active_size; i++)
  427. {
  428. if (is_free(i))
  429. {
  430. const Qfloat* Q_i = Q->get_Q(i, l);
  431. const double alpha_i = alpha[i];
  432. for (j = active_size; j < l; j++) { G[j] += alpha_i * Q_i[j]; }
  433. }
  434. }
  435. }
  436. }
  // Runs the SMO iteration on the problem given by (Q, p_, y_, Cp, Cn).
  //   l         number of variables
  //   Q         Q matrix; only read through get_Q / get_QD, reordered through swap_index
  //   p_, y_    linear term and labels; copied, the caller's arrays are not modified
  //   alpha_    initial point on entry; overwritten with the solution on return
  //   Cp, Cn    upper bounds selected per variable by get_C()
  //   eps       stopping tolerance used by select_working_set()
  //   si        receives rho, obj, upper_bound_p and upper_bound_n
  //   shrinking non-zero enables the periodic call to do_shrinking()
  // All working arrays are allocated here and released before returning.
  437. void Solver::Solve(const int l, const QMatrix& Q, const double* p_, const schar* y_, double* alpha_, const double Cp, const double Cn, const double eps,
  438. SolutionInfo* si, const int shrinking)
  439. {
  // The parameters shadow the members of the same name, hence the explicit this->.
  440. this->l = l;
  441. this->Q = &Q;
  442. QD = Q.get_QD();
  443. clone(p, p_, l);
  444. clone(y, y_, l);
  445. clone(alpha, alpha_, l);
  446. this->Cp = Cp;
  447. this->Cn = Cn;
  448. this->eps = eps;
  449. unshrink = false;
  450. // initialize alpha_status
  451. {
  452. alpha_status = new char[l];
  453. for (int i = 0; i < l; i++) { update_alpha_status(i); }
  454. }
  455. // initialize active set (for shrinking)
  456. {
  457. active_set = new int[l];
  458. for (int i = 0; i < l; i++) { active_set[i] = i; }
  459. active_size = l;
  460. }
  461. // initialize gradient
  462. {
  463. G = new double[l];
  464. G_bar = new double[l];
  465. int i;
  466. for (i = 0; i < l; i++)
  467. {
  468. G[i] = p[i];
  469. G_bar[i] = 0;
  470. }
  // Only variables that are not at their lower bound contribute a column to G;
  // those at the upper bound additionally contribute C * column to G_bar.
  471. for (i = 0; i < l; i++)
  472. {
  473. if (!is_lower_bound(i))
  474. {
  475. const Qfloat* Q_i = Q.get_Q(i, l);
  476. const double alpha_i = alpha[i];
  477. int j;
  478. for (j = 0; j < l; j++) { G[j] += alpha_i * Q_i[j]; }
  479. if (is_upper_bound(i)) { for (j = 0; j < l; j++) { G_bar[j] += get_C(i) * Q_i[j]; } }
  480. }
  481. }
  482. }
  483. // optimization step
  484. int iter = 0;
  // Iteration cap: at least 10,000,000, otherwise 100 * l clamped to INT_MAX.
  485. const int max_iter = max(10000000, l > INT_MAX / 100 ? INT_MAX : 100 * l);
  // Countdown to the next shrinking / progress step.
  486. int counter = min(l, 1000) + 1;
  487. while (iter < max_iter)
  488. {
  489. // show progress and do shrinking
  490. if (--counter == 0)
  491. {
  492. counter = min(l, 1000);
  493. if (shrinking) { do_shrinking(); }
  494. info(".");
  495. }
  496. int i, j;
  // A non-zero result means no violating pair was found inside the active set.
  497. if (select_working_set(i, j) != 0)
  498. {
  499. // reconstruct the whole gradient
  500. reconstruct_gradient();
  501. // reset active set size and check
  502. active_size = l;
  503. info("*");
  504. if (select_working_set(i, j) != 0) { break; } // do shrinking next iteration
  505. counter = 1;
  506. }
  507. ++iter;
  508. // update alpha[i] and alpha[j], handle bounds carefully
  509. const Qfloat* Q_i = Q.get_Q(i, active_size);
  510. const Qfloat* Q_j = Q.get_Q(j, active_size);
  511. const double C_i = get_C(i);
  512. const double C_j = get_C(j);
  513. const double old_alpha_i = alpha[i];
  514. const double old_alpha_j = alpha[j];
  // Different labels: both variables move by the same delta, so alpha[i] - alpha[j] is kept.
  515. if (y[i] != y[j])
  516. {
  517. double quad_coef = QD[i] + QD[j] + 2 * Q_i[j];
  518. if (quad_coef <= 0) { quad_coef = TAU; }
  519. const double delta = (-G[i] - G[j]) / quad_coef;
  520. const double diff = alpha[i] - alpha[j];
  521. alpha[i] += delta;
  522. alpha[j] += delta;
  // Clip back into the box [0, C_i] x [0, C_j] while preserving diff.
  523. if (diff > 0)
  524. {
  525. if (alpha[j] < 0)
  526. {
  527. alpha[j] = 0;
  528. alpha[i] = diff;
  529. }
  530. }
  531. else
  532. {
  533. if (alpha[i] < 0)
  534. {
  535. alpha[i] = 0;
  536. alpha[j] = -diff;
  537. }
  538. }
  539. if (diff > C_i - C_j)
  540. {
  541. if (alpha[i] > C_i)
  542. {
  543. alpha[i] = C_i;
  544. alpha[j] = C_i - diff;
  545. }
  546. }
  547. else
  548. {
  549. if (alpha[j] > C_j)
  550. {
  551. alpha[j] = C_j;
  552. alpha[i] = C_j + diff;
  553. }
  554. }
  555. }
  // Same label: the variables move in opposite directions, so alpha[i] + alpha[j] is kept.
  556. else
  557. {
  558. double quad_coef = QD[i] + QD[j] - 2 * Q_i[j];
  559. if (quad_coef <= 0) { quad_coef = TAU; }
  560. const double delta = (G[i] - G[j]) / quad_coef;
  561. const double sum = alpha[i] + alpha[j];
  562. alpha[i] -= delta;
  563. alpha[j] += delta;
  // Clip back into the box [0, C_i] x [0, C_j] while preserving sum.
  564. if (sum > C_i)
  565. {
  566. if (alpha[i] > C_i)
  567. {
  568. alpha[i] = C_i;
  569. alpha[j] = sum - C_i;
  570. }
  571. }
  572. else
  573. {
  574. if (alpha[j] < 0)
  575. {
  576. alpha[j] = 0;
  577. alpha[i] = sum;
  578. }
  579. }
  580. if (sum > C_j)
  581. {
  582. if (alpha[j] > C_j)
  583. {
  584. alpha[j] = C_j;
  585. alpha[i] = sum - C_j;
  586. }
  587. }
  588. else
  589. {
  590. if (alpha[i] < 0)
  591. {
  592. alpha[i] = 0;
  593. alpha[j] = sum;
  594. }
  595. }
  596. }
  597. // update G
  598. const double delta_alpha_i = alpha[i] - old_alpha_i;
  599. const double delta_alpha_j = alpha[j] - old_alpha_j;
  600. for (int k = 0; k < active_size; k++) { G[k] += Q_i[k] * delta_alpha_i + Q_j[k] * delta_alpha_j; }
  601. // update alpha_status and G_bar
  602. {
  // G_bar only changes when a variable enters or leaves its upper bound;
  // in that case the full column (length l) is needed.
  603. const bool ui = is_upper_bound(i);
  604. const bool uj = is_upper_bound(j);
  605. update_alpha_status(i);
  606. update_alpha_status(j);
  607. int k;
  608. if (ui != is_upper_bound(i))
  609. {
  610. Q_i = Q.get_Q(i, l);
  611. if (ui) { for (k = 0; k < l; k++) { G_bar[k] -= C_i * Q_i[k]; } }
  612. else { for (k = 0; k < l; k++) { G_bar[k] += C_i * Q_i[k]; } }
  613. }
  614. if (uj != is_upper_bound(j))
  615. {
  616. Q_j = Q.get_Q(j, l);
  617. if (uj) { for (k = 0; k < l; k++) { G_bar[k] -= C_j * Q_j[k]; } }
  618. else { for (k = 0; k < l; k++) { G_bar[k] += C_j * Q_j[k]; } }
  619. }
  620. }
  621. }
  // The loop ended because the iteration cap was hit, not because of optimality.
  622. if (iter >= max_iter)
  623. {
  624. if (active_size < l)
  625. {
  626. // reconstruct the whole gradient to calculate objective value
  627. reconstruct_gradient();
  628. active_size = l;
  629. info("*");
  630. }
  631. fprintf(stderr, "\nWARNING: reaching max number of iterations\n");
  632. }
  633. // calculate rho
  634. si->rho = calculate_rho();
  635. // calculate objective value
  636. {
  637. double v = 0;
  638. for (int i = 0; i < l; i++) { v += alpha[i] * (G[i] + p[i]); }
  639. si->obj = v / 2;
  640. }
  641. // put back the solution
  642. {
  // active_set maps each current position back to its original index.
  643. for (int i = 0; i < l; i++) { alpha_[active_set[i]] = alpha[i]; }
  644. }
  645. // juggle everything back
  646. /*{
  647. for(int i=0;i<l;i++)
  648. while(active_set[i] != i)
  649. swap_index(i,active_set[i]);
  650. // or Q.swap_index(i,active_set[i]);
  651. }*/
  652. si->upper_bound_p = Cp;
  653. si->upper_bound_n = Cn;
  654. info("\noptimization finished, #iter = %d\n", iter);
  // Release the working arrays; the member pointers are not reset afterwards.
  655. delete[] p;
  656. delete[] y;
  657. delete[] alpha;
  658. delete[] alpha_status;
  659. delete[] active_set;
  660. delete[] G;
  661. delete[] G_bar;
  662. }
  // Chooses the pair of variables to optimise next.
  // On success stores the pair in out_i / out_j; both are left untouched otherwise.
  663. // return 1 if already optimal, return 0 otherwise
  664. int Solver::select_working_set(int& out_i, int& out_j)
  665. {
  666. // return i,j such that
  667. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
  668. // j: minimizes the decrease of obj value
  669. // (if quadratic coefficient <= 0, replace it with tau)
  670. // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
  671. double Gmax = -INF;
  672. double Gmax2 = -INF;
  673. int Gmax_idx = -1;
  674. int Gmin_idx = -1;
  675. double obj_diff_min = INF;
  // First pass: pick i. Ties go to the later index because of the >= comparisons.
  676. for (int t = 0; t < active_size; t++)
  677. {
  678. if (y[t] == +1)
  679. {
  680. if (!is_upper_bound(t))
  681. {
  682. if (-G[t] >= Gmax)
  683. {
  684. Gmax = -G[t];
  685. Gmax_idx = t;
  686. }
  687. }
  688. }
  689. else
  690. {
  691. if (!is_lower_bound(t))
  692. {
  693. if (G[t] >= Gmax)
  694. {
  695. Gmax = G[t];
  696. Gmax_idx = t;
  697. }
  698. }
  699. }
  700. }
  701. const int i = Gmax_idx;
  702. const Qfloat* Q_i = nullptr;
  703. if (i != -1) { Q_i = Q->get_Q(i, active_size); } // NULL Q_i not accessed: Gmax=-INF if i=-1
  // Second pass: pick j, tracking Gmax2 for the stopping test along the way.
  704. for (int j = 0; j < active_size; j++)
  705. {
  706. if (y[j] == +1)
  707. {
  708. if (!is_lower_bound(j))
  709. {
  710. const double grad_diff = Gmax + G[j];
  711. if (G[j] >= Gmax2) { Gmax2 = G[j]; }
  712. if (grad_diff > 0)
  713. {
  714. double obj_diff;
  715. const double quad_coef = QD[i] + QD[j] - 2.0 * y[i] * Q_i[j];
  716. if (quad_coef > 0) { obj_diff = -(grad_diff * grad_diff) / quad_coef; }
  717. else { obj_diff = -(grad_diff * grad_diff) / TAU; }
  718. if (obj_diff <= obj_diff_min)
  719. {
  720. Gmin_idx = j;
  721. obj_diff_min = obj_diff;
  722. }
  723. }
  724. }
  725. }
  726. else
  727. {
  728. if (!is_upper_bound(j))
  729. {
  730. const double grad_diff = Gmax - G[j];
  731. if (-G[j] >= Gmax2) { Gmax2 = -G[j]; }
  732. if (grad_diff > 0)
  733. {
  734. double obj_diff;
  735. const double quad_coef = QD[i] + QD[j] + 2.0 * y[i] * Q_i[j];
  736. if (quad_coef > 0) { obj_diff = -(grad_diff * grad_diff) / quad_coef; }
  737. else { obj_diff = -(grad_diff * grad_diff) / TAU; }
  738. if (obj_diff <= obj_diff_min)
  739. {
  740. Gmin_idx = j;
  741. obj_diff_min = obj_diff;
  742. }
  743. }
  744. }
  745. }
  746. }
  // Stop when the maximal violation is below eps or no admissible j exists.
  747. if (Gmax + Gmax2 < eps || Gmin_idx == -1) { return 1; }
  748. out_i = Gmax_idx;
  749. out_j = Gmin_idx;
  750. return 0;
  751. }
  752. bool Solver::be_shrunk(const int i, const double Gmax1, const double Gmax2)
  753. {
  754. if (is_upper_bound(i))
  755. {
  756. if (y[i] == +1) { return (-G[i] > Gmax1); }
  757. return (-G[i] > Gmax2);
  758. }
  759. if (is_lower_bound(i))
  760. {
  761. if (y[i] == +1) { return (G[i] > Gmax2); }
  762. return (G[i] > Gmax1);
  763. }
  764. return (false);
  765. }
  766. void Solver::do_shrinking()
  767. {
  768. int i;
  769. double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
  770. double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
  771. // find maximal violating pair first
  772. for (i = 0; i < active_size; i++)
  773. {
  774. if (y[i] == +1)
  775. {
  776. if (!is_upper_bound(i)) { if (-G[i] >= Gmax1) { Gmax1 = -G[i]; } }
  777. if (!is_lower_bound(i)) { if (G[i] >= Gmax2) { Gmax2 = G[i]; } }
  778. }
  779. else
  780. {
  781. if (!is_upper_bound(i)) { if (-G[i] >= Gmax2) { Gmax2 = -G[i]; } }
  782. if (!is_lower_bound(i)) { if (G[i] >= Gmax1) { Gmax1 = G[i]; } }
  783. }
  784. }
  785. if (unshrink == false && Gmax1 + Gmax2 <= eps * 10)
  786. {
  787. unshrink = true;
  788. reconstruct_gradient();
  789. active_size = l;
  790. info("*");
  791. }
  792. for (i = 0; i < active_size; i++)
  793. {
  794. if (be_shrunk(i, Gmax1, Gmax2))
  795. {
  796. active_size--;
  797. while (active_size > i)
  798. {
  799. if (!be_shrunk(active_size, Gmax1, Gmax2))
  800. {
  801. swap_index(i, active_size);
  802. break;
  803. }
  804. active_size--;
  805. }
  806. }
  807. }
  808. }
  809. double Solver::calculate_rho()
  810. {
  811. double r;
  812. int nr_free = 0;
  813. double ub = INF, lb = -INF, sum_free = 0;
  814. for (int i = 0; i < active_size; i++)
  815. {
  816. const double yG = y[i] * G[i];
  817. if (is_upper_bound(i))
  818. {
  819. if (y[i] == -1) { ub = min(ub, yG); }
  820. else { lb = max(lb, yG); }
  821. }
  822. else if (is_lower_bound(i))
  823. {
  824. if (y[i] == +1) { ub = min(ub, yG); }
  825. else { lb = max(lb, yG); }
  826. }
  827. else
  828. {
  829. ++nr_free;
  830. sum_free += yG;
  831. }
  832. }
  833. if (nr_free > 0) { r = sum_free / nr_free; }
  834. else { r = (ub + lb) / 2; }
  835. return r;
  836. }
  837. //
  838. // Solver for nu-svm classification and regression
  839. //
  840. // additional constraint: e^T \alpha = constant
  841. //
// Solver specialization for the nu formulations. It overrides working-set
// selection, rho computation and shrinking, and keeps a pointer to the
// caller's SolutionInfo so that calculate_rho() can also store si->r.
class Solver_NU : public Solver
{
public:
    Solver_NU() {}
    // Remembers si, then forwards every argument unchanged to Solver::Solve.
    void Solve(const int l, const QMatrix& Q, const double* p, const schar* y, double* alpha, const double Cp, const double Cn, const double eps,
               SolutionInfo* si, const int shrinking)
    {
        this->si = si;
        Solver::Solve(l, Q, p, y, alpha, Cp, Cn, eps, si, shrinking);
    }
private:
    SolutionInfo* si = nullptr; // set by Solve(); written to in calculate_rho()
    int select_working_set(int& out_i, int& out_j) override;
    double calculate_rho() override;
    // Four-threshold variant: separate maxima are kept per label.
    bool be_shrunk(const int i, const double Gmax1, const double Gmax2, const double Gmax3, const double Gmax4);
    void do_shrinking() override;
};
  859. // return 1 if already optimal, return 0 otherwise
  860. int Solver_NU::select_working_set(int& out_i, int& out_j)
  861. {
  862. // return i,j such that y_i = y_j and
  863. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
  864. // j: minimizes the decrease of obj value
  865. // (if quadratic coefficeint <= 0, replace it with tau)
  866. // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
  867. double Gmaxp = -INF;
  868. double Gmaxp2 = -INF;
  869. int Gmaxp_idx = -1;
  870. double Gmaxn = -INF;
  871. double Gmaxn2 = -INF;
  872. int Gmaxn_idx = -1;
  873. int Gmin_idx = -1;
  874. double obj_diff_min = INF;
  875. for (int t = 0; t < active_size; t++)
  876. {
  877. if (y[t] == +1)
  878. {
  879. if (!is_upper_bound(t))
  880. {
  881. if (-G[t] >= Gmaxp)
  882. {
  883. Gmaxp = -G[t];
  884. Gmaxp_idx = t;
  885. }
  886. }
  887. }
  888. else
  889. {
  890. if (!is_lower_bound(t))
  891. {
  892. if (G[t] >= Gmaxn)
  893. {
  894. Gmaxn = G[t];
  895. Gmaxn_idx = t;
  896. }
  897. }
  898. }
  899. }
  900. const int ip = Gmaxp_idx;
  901. const int in = Gmaxn_idx;
  902. const Qfloat* Q_ip = nullptr;
  903. const Qfloat* Q_in = nullptr;
  904. if (ip != -1) { Q_ip = Q->get_Q(ip, active_size); } // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
  905. if (in != -1) { Q_in = Q->get_Q(in, active_size); }
  906. for (int j = 0; j < active_size; j++)
  907. {
  908. if (y[j] == +1)
  909. {
  910. if (!is_lower_bound(j))
  911. {
  912. const double grad_diff = Gmaxp + G[j];
  913. if (G[j] >= Gmaxp2) { Gmaxp2 = G[j]; }
  914. if (grad_diff > 0)
  915. {
  916. double obj_diff;
  917. const double quad_coef = QD[ip] + QD[j] - 2 * Q_ip[j];
  918. if (quad_coef > 0) { obj_diff = -(grad_diff * grad_diff) / quad_coef; }
  919. else { obj_diff = -(grad_diff * grad_diff) / TAU; }
  920. if (obj_diff <= obj_diff_min)
  921. {
  922. Gmin_idx = j;
  923. obj_diff_min = obj_diff;
  924. }
  925. }
  926. }
  927. }
  928. else
  929. {
  930. if (!is_upper_bound(j))
  931. {
  932. const double grad_diff = Gmaxn - G[j];
  933. if (-G[j] >= Gmaxn2) { Gmaxn2 = -G[j]; }
  934. if (grad_diff > 0)
  935. {
  936. double obj_diff;
  937. const double quad_coef = QD[in] + QD[j] - 2 * Q_in[j];
  938. if (quad_coef > 0) { obj_diff = -(grad_diff * grad_diff) / quad_coef; }
  939. else { obj_diff = -(grad_diff * grad_diff) / TAU; }
  940. if (obj_diff <= obj_diff_min)
  941. {
  942. Gmin_idx = j;
  943. obj_diff_min = obj_diff;
  944. }
  945. }
  946. }
  947. }
  948. }
  949. if (max(Gmaxp + Gmaxp2, Gmaxn + Gmaxn2) < eps || Gmin_idx == -1) { return 1; }
  950. if (y[Gmin_idx] == +1) { out_i = Gmaxp_idx; }
  951. else { out_i = Gmaxn_idx; }
  952. out_j = Gmin_idx;
  953. return 0;
  954. }
  955. bool Solver_NU::be_shrunk(const int i, const double Gmax1, const double Gmax2, const double Gmax3, const double Gmax4)
  956. {
  957. if (is_upper_bound(i))
  958. {
  959. if (y[i] == +1) { return (-G[i] > Gmax1); }
  960. return (-G[i] > Gmax4);
  961. }
  962. if (is_lower_bound(i))
  963. {
  964. if (y[i] == +1) { return (G[i] > Gmax2); }
  965. return (G[i] > Gmax3);
  966. }
  967. return (false);
  968. }
  969. void Solver_NU::do_shrinking()
  970. {
  971. double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
  972. double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
  973. double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
  974. double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
  975. // find maximal violating pair first
  976. int i;
  977. for (i = 0; i < active_size; i++)
  978. {
  979. if (!is_upper_bound(i))
  980. {
  981. if (y[i] == +1) { if (-G[i] > Gmax1) { Gmax1 = -G[i]; } }
  982. else if (-G[i] > Gmax4) { Gmax4 = -G[i]; }
  983. }
  984. if (!is_lower_bound(i))
  985. {
  986. if (y[i] == +1) { if (G[i] > Gmax2) { Gmax2 = G[i]; } }
  987. else if (G[i] > Gmax3) { Gmax3 = G[i]; }
  988. }
  989. }
  990. if (unshrink == false && max(Gmax1 + Gmax2, Gmax3 + Gmax4) <= eps * 10)
  991. {
  992. unshrink = true;
  993. reconstruct_gradient();
  994. active_size = l;
  995. }
  996. for (i = 0; i < active_size; i++)
  997. {
  998. if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
  999. {
  1000. active_size--;
  1001. while (active_size > i)
  1002. {
  1003. if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
  1004. {
  1005. swap_index(i, active_size);
  1006. break;
  1007. }
  1008. active_size--;
  1009. }
  1010. }
  1011. }
  1012. }
  1013. double Solver_NU::calculate_rho()
  1014. {
  1015. int nr_free1 = 0, nr_free2 = 0;
  1016. double ub1 = INF, ub2 = INF;
  1017. double lb1 = -INF, lb2 = -INF;
  1018. double sum_free1 = 0, sum_free2 = 0;
  1019. for (int i = 0; i < active_size; i++)
  1020. {
  1021. if (y[i] == +1)
  1022. {
  1023. if (is_upper_bound(i)) { lb1 = max(lb1, G[i]); }
  1024. else if (is_lower_bound(i)) { ub1 = min(ub1, G[i]); }
  1025. else
  1026. {
  1027. ++nr_free1;
  1028. sum_free1 += G[i];
  1029. }
  1030. }
  1031. else
  1032. {
  1033. if (is_upper_bound(i)) { lb2 = max(lb2, G[i]); }
  1034. else if (is_lower_bound(i)) { ub2 = min(ub2, G[i]); }
  1035. else
  1036. {
  1037. ++nr_free2;
  1038. sum_free2 += G[i];
  1039. }
  1040. }
  1041. }
  1042. double r1, r2;
  1043. if (nr_free1 > 0) { r1 = sum_free1 / nr_free1; }
  1044. else { r1 = (ub1 + lb1) / 2; }
  1045. if (nr_free2 > 0) { r2 = sum_free2 / nr_free2; }
  1046. else { r2 = (ub2 + lb2) / 2; }
  1047. si->r = (r1 + r2) / 2;
  1048. return (r1 - r2) / 2;
  1049. }
  1050. //
  1051. // Q matrices for various formulations
  1052. //
  1053. class SVC_Q : public Kernel
  1054. {
  1055. public:
  1056. SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar* y_)
  1057. : Kernel(prob.l, prob.x, param)
  1058. {
  1059. clone(y, y_, prob.l);
  1060. cache = new Cache(prob.l, long(param.cache_size * (1 << 20)));
  1061. QD = new double[prob.l];
  1062. for (int i = 0; i < prob.l; i++) { QD[i] = (this->*kernel_function)(i, i); }
  1063. }
  1064. Qfloat* get_Q(const int i, const int len) const override
  1065. {
  1066. Qfloat* data;
  1067. int start;
  1068. if ((start = cache->get_data(i, &data, len)) < len)
  1069. {
  1070. for (int j = start; j < len; j++) { data[j] = Qfloat(y[i] * y[j] * (this->*kernel_function)(i, j)); }
  1071. }
  1072. return data;
  1073. }
  1074. double* get_QD() const override { return QD; }
  1075. void swap_index(const int i, const int j) const override
  1076. {
  1077. cache->swap_index(i, j);
  1078. Kernel::swap_index(i, j);
  1079. swap(y[i], y[j]);
  1080. swap(QD[i], QD[j]);
  1081. }
  1082. ~SVC_Q() override
  1083. {
  1084. delete[] y;
  1085. delete cache;
  1086. delete[] QD;
  1087. }
  1088. private:
  1089. schar* y;
  1090. Cache* cache;
  1091. double* QD;
  1092. };
  1093. class ONE_CLASS_Q : public Kernel
  1094. {
  1095. public:
  1096. ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
  1097. : Kernel(prob.l, prob.x, param)
  1098. {
  1099. cache = new Cache(prob.l, long(param.cache_size * (1 << 20)));
  1100. QD = new double[prob.l];
  1101. for (int i = 0; i < prob.l; i++) { QD[i] = (this->*kernel_function)(i, i); }
  1102. }
  1103. Qfloat* get_Q(const int i, const int len) const override
  1104. {
  1105. Qfloat* data;
  1106. int start;
  1107. if ((start = cache->get_data(i, &data, len)) < len) { for (int j = start; j < len; j++) { data[j] = Qfloat((this->*kernel_function)(i, j)); } }
  1108. return data;
  1109. }
  1110. double* get_QD() const override { return QD; }
  1111. void swap_index(const int i, const int j) const override
  1112. {
  1113. cache->swap_index(i, j);
  1114. Kernel::swap_index(i, j);
  1115. swap(QD[i], QD[j]);
  1116. }
  1117. ~ONE_CLASS_Q() override
  1118. {
  1119. delete cache;
  1120. delete[] QD;
  1121. }
  1122. private:
  1123. Cache* cache;
  1124. double* QD;
  1125. };
  1126. class SVR_Q : public Kernel
  1127. {
  1128. public:
  1129. SVR_Q(const svm_problem& prob, const svm_parameter& param)
  1130. : Kernel(prob.l, prob.x, param)
  1131. {
  1132. l = prob.l;
  1133. cache = new Cache(l, long(param.cache_size * (1 << 20)));
  1134. QD = new double[2 * l];
  1135. sign = new schar[2 * l];
  1136. index = new int[2 * l];
  1137. for (int k = 0; k < l; k++)
  1138. {
  1139. sign[k] = 1;
  1140. sign[k + l] = -1;
  1141. index[k] = k;
  1142. index[k + l] = k;
  1143. QD[k] = (this->*kernel_function)(k, k);
  1144. QD[k + l] = QD[k];
  1145. }
  1146. buffer[0] = new Qfloat[2 * l];
  1147. buffer[1] = new Qfloat[2 * l];
  1148. next_buffer = 0;
  1149. }
  1150. void swap_index(const int i, const int j) const override
  1151. {
  1152. swap(sign[i], sign[j]);
  1153. swap(index[i], index[j]);
  1154. swap(QD[i], QD[j]);
  1155. }
  1156. Qfloat* get_Q(const int i, const int len) const override
  1157. {
  1158. Qfloat* data;
  1159. int j, real_i = index[i];
  1160. if (cache->get_data(real_i, &data, l) < l) { for (j = 0; j < l; j++) { data[j] = Qfloat((this->*kernel_function)(real_i, j)); } }
  1161. // reorder and copy
  1162. Qfloat* buf = buffer[next_buffer];
  1163. next_buffer = 1 - next_buffer;
  1164. const schar si = sign[i];
  1165. for (j = 0; j < len; j++) { buf[j] = Qfloat(si) * Qfloat(sign[j]) * data[index[j]]; }
  1166. return buf;
  1167. }
  1168. double* get_QD() const override { return QD; }
  1169. ~SVR_Q() override
  1170. {
  1171. delete cache;
  1172. delete[] sign;
  1173. delete[] index;
  1174. delete[] buffer[0];
  1175. delete[] buffer[1];
  1176. delete[] QD;
  1177. }
  1178. private:
  1179. int l;
  1180. Cache* cache;
  1181. schar* sign;
  1182. int* index;
  1183. mutable int next_buffer;
  1184. Qfloat* buffer[2];
  1185. double* QD;
  1186. };
  1187. //
  1188. // construct and solve various formulations
  1189. //
  1190. static void solve_c_svc(const svm_problem* prob, const svm_parameter* param, double* alpha, Solver::SolutionInfo* si, const double Cp, const double Cn)
  1191. {
  1192. const int l = prob->l;
  1193. double* minus_ones = new double[l];
  1194. schar* y = new schar[l];
  1195. int i;
  1196. for (i = 0; i < l; i++)
  1197. {
  1198. alpha[i] = 0;
  1199. minus_ones[i] = -1;
  1200. if (prob->y[i] > 0) { y[i] = +1; }
  1201. else { y[i] = -1; }
  1202. }
  1203. Solver s;
  1204. s.Solve(l, SVC_Q(*prob, *param, y), minus_ones, y, alpha, Cp, Cn, param->eps, si, param->shrinking);
  1205. double sum_alpha = 0;
  1206. for (i = 0; i < l; i++) { sum_alpha += alpha[i]; }
  1207. if (Cp == Cn) { info("nu = %f\n", sum_alpha / (Cp * prob->l)); }
  1208. for (i = 0; i < l; i++) { alpha[i] *= y[i]; }
  1209. delete[] minus_ones;
  1210. delete[] y;
  1211. }
  1212. static void solve_nu_svc(const svm_problem* prob, const svm_parameter* param, double* alpha, Solver::SolutionInfo* si)
  1213. {
  1214. int i;
  1215. const int l = prob->l;
  1216. const double nu = param->nu;
  1217. schar* y = new schar[l];
  1218. for (i = 0; i < l; i++)
  1219. {
  1220. if (prob->y[i] > 0) { y[i] = +1; }
  1221. else { y[i] = -1; }
  1222. }
  1223. double sum_pos = nu * l / 2;
  1224. double sum_neg = nu * l / 2;
  1225. for (i = 0; i < l; i++)
  1226. {
  1227. if (y[i] == +1)
  1228. {
  1229. alpha[i] = min(1.0, sum_pos);
  1230. sum_pos -= alpha[i];
  1231. }
  1232. else
  1233. {
  1234. alpha[i] = min(1.0, sum_neg);
  1235. sum_neg -= alpha[i];
  1236. }
  1237. }
  1238. double* zeros = new double[l];
  1239. for (i = 0; i < l; i++) { zeros[i] = 0; }
  1240. Solver_NU s;
  1241. s.Solve(l, SVC_Q(*prob, *param, y), zeros, y, alpha, 1.0, 1.0, param->eps, si, param->shrinking);
  1242. const double r = si->r;
  1243. info("C = %f\n", 1 / r);
  1244. for (i = 0; i < l; i++) { alpha[i] *= y[i] / r; }
  1245. si->rho /= r;
  1246. si->obj /= (r * r);
  1247. si->upper_bound_p = 1 / r;
  1248. si->upper_bound_n = 1 / r;
  1249. delete[] y;
  1250. delete[] zeros;
  1251. }
  1252. static void solve_one_class(const svm_problem* prob, const svm_parameter* param, double* alpha, Solver::SolutionInfo* si)
  1253. {
  1254. const int l = prob->l;
  1255. double* zeros = new double[l];
  1256. schar* ones = new schar[l];
  1257. int i;
  1258. const int n = int(param->nu * prob->l); // # of alpha's at upper bound
  1259. for (i = 0; i < n; i++) { alpha[i] = 1; }
  1260. if (n < prob->l) { alpha[n] = param->nu * prob->l - n; }
  1261. for (i = n + 1; i < l; i++) { alpha[i] = 0; }
  1262. for (i = 0; i < l; i++)
  1263. {
  1264. zeros[i] = 0;
  1265. ones[i] = 1;
  1266. }
  1267. Solver s;
  1268. s.Solve(l, ONE_CLASS_Q(*prob, *param), zeros, ones, alpha, 1.0, 1.0, param->eps, si, param->shrinking);
  1269. delete[] zeros;
  1270. delete[] ones;
  1271. }
  1272. static void solve_epsilon_svr(const svm_problem* prob, const svm_parameter* param, double* alpha, Solver::SolutionInfo* si)
  1273. {
  1274. const int l = prob->l;
  1275. double* alpha2 = new double[2 * l];
  1276. double* linear_term = new double[2 * l];
  1277. schar* y = new schar[2 * l];
  1278. int i;
  1279. for (i = 0; i < l; i++)
  1280. {
  1281. alpha2[i] = 0;
  1282. linear_term[i] = param->p - prob->y[i];
  1283. y[i] = 1;
  1284. alpha2[i + l] = 0;
  1285. linear_term[i + l] = param->p + prob->y[i];
  1286. y[i + l] = -1;
  1287. }
  1288. Solver s;
  1289. s.Solve(2 * l, SVR_Q(*prob, *param), linear_term, y, alpha2, param->C, param->C, param->eps, si, param->shrinking);
  1290. double sum_alpha = 0;
  1291. for (i = 0; i < l; i++)
  1292. {
  1293. alpha[i] = alpha2[i] - alpha2[i + l];
  1294. sum_alpha += fabs(alpha[i]);
  1295. }
  1296. info("nu = %f\n", sum_alpha / (param->C * l));
  1297. delete[] alpha2;
  1298. delete[] linear_term;
  1299. delete[] y;
  1300. }
  1301. static void solve_nu_svr(const svm_problem* prob, const svm_parameter* param, double* alpha, Solver::SolutionInfo* si)
  1302. {
  1303. const int l = prob->l;
  1304. const double C = param->C;
  1305. double* alpha2 = new double[2 * l];
  1306. double* linear_term = new double[2 * l];
  1307. schar* y = new schar[2 * l];
  1308. int i;
  1309. double sum = C * param->nu * l / 2;
  1310. for (i = 0; i < l; i++)
  1311. {
  1312. alpha2[i] = alpha2[i + l] = min(sum, C);
  1313. sum -= alpha2[i];
  1314. linear_term[i] = -prob->y[i];
  1315. y[i] = 1;
  1316. linear_term[i + l] = prob->y[i];
  1317. y[i + l] = -1;
  1318. }
  1319. Solver_NU s;
  1320. s.Solve(2 * l, SVR_Q(*prob, *param), linear_term, y, alpha2, C, C, param->eps, si, param->shrinking);
  1321. info("epsilon = %f\n", -si->r);
  1322. for (i = 0; i < l; i++) { alpha[i] = alpha2[i] - alpha2[i + l]; }
  1323. delete[] alpha2;
  1324. delete[] linear_term;
  1325. delete[] y;
  1326. }
  1327. //
  1328. // decision_function
  1329. //
// Result of training a single sub-problem (see svm_train_one).
struct decision_function
{
    double* alpha; // one coefficient per training sample; malloc'ed by svm_train_one, freed by its caller
    double rho;    // offset term, copied from Solver::SolutionInfo::rho
};
  1335. static decision_function svm_train_one(const svm_problem* prob, const svm_parameter* param, const double Cp, const double Cn)
  1336. {
  1337. double* alpha = Malloc(double, prob->l);
  1338. Solver::SolutionInfo si;
  1339. switch (param->svm_type)
  1340. {
  1341. case C_SVC:
  1342. solve_c_svc(prob, param, alpha, &si, Cp, Cn);
  1343. break;
  1344. case NU_SVC:
  1345. solve_nu_svc(prob, param, alpha, &si);
  1346. break;
  1347. case ONE_CLASS:
  1348. solve_one_class(prob, param, alpha, &si);
  1349. break;
  1350. case EPSILON_SVR:
  1351. solve_epsilon_svr(prob, param, alpha, &si);
  1352. break;
  1353. case NU_SVR:
  1354. solve_nu_svr(prob, param, alpha, &si);
  1355. break;
  1356. }
  1357. info("obj = %f, rho = %f\n", si.obj, si.rho);
  1358. // output SVs
  1359. int nSV = 0;
  1360. int nBSV = 0;
  1361. for (int i = 0; i < prob->l; i++)
  1362. {
  1363. if (fabs(alpha[i]) > 0)
  1364. {
  1365. ++nSV;
  1366. if (prob->y[i] > 0) { if (fabs(alpha[i]) >= si.upper_bound_p) { ++nBSV; } }
  1367. else { if (fabs(alpha[i]) >= si.upper_bound_n) { ++nBSV; } }
  1368. }
  1369. }
  1370. info("nSV = %d, nBSV = %d\n", nSV, nBSV);
  1371. decision_function f;
  1372. f.alpha = alpha;
  1373. f.rho = si.rho;
  1374. return f;
  1375. }
  1376. // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
  1377. static void sigmoid_train(const int l, const double* dec_values, const double* labels, double& A, double& B)
  1378. {
  1379. double prior1 = 0, prior0 = 0;
  1380. int i;
  1381. for (i = 0; i < l; i++)
  1382. {
  1383. if (labels[i] > 0) { prior1 += 1; }
  1384. else { prior0 += 1; }
  1385. }
  1386. const int max_iter = 100; // Maximal number of iterations
  1387. const double min_step = 1e-10; // Minimal step taken in line search
  1388. const double sigma = 1e-12; // For numerically strict PD of Hessian
  1389. const double eps = 1e-5;
  1390. const double hiTarget = (prior1 + 1.0) / (prior1 + 2.0);
  1391. const double loTarget = 1 / (prior0 + 2.0);
  1392. double* t = Malloc(double, l);
  1393. double fApB, p, q;
  1394. int iter;
  1395. // Initial Point and Initial Fun Value
  1396. A = 0.0;
  1397. B = log((prior0 + 1.0) / (prior1 + 1.0));
  1398. double fval = 0.0;
  1399. for (i = 0; i < l; i++)
  1400. {
  1401. if (labels[i] > 0) { t[i] = hiTarget; }
  1402. else { t[i] = loTarget; }
  1403. fApB = dec_values[i] * A + B;
  1404. if (fApB >= 0) { fval += t[i] * fApB + log(1 + exp(-fApB)); }
  1405. else { fval += (t[i] - 1) * fApB + log(1 + exp(fApB)); }
  1406. }
  1407. for (iter = 0; iter < max_iter; iter++)
  1408. {
  1409. // Update Gradient and Hessian (use H' = H + sigma I)
  1410. double h11 = sigma; // numerically ensures strict PD
  1411. double h22 = sigma;
  1412. double h21 = 0.0;
  1413. double g1 = 0.0;
  1414. double g2 = 0.0;
  1415. for (i = 0; i < l; i++)
  1416. {
  1417. fApB = dec_values[i] * A + B;
  1418. if (fApB >= 0)
  1419. {
  1420. p = exp(-fApB) / (1.0 + exp(-fApB));
  1421. q = 1.0 / (1.0 + exp(-fApB));
  1422. }
  1423. else
  1424. {
  1425. p = 1.0 / (1.0 + exp(fApB));
  1426. q = exp(fApB) / (1.0 + exp(fApB));
  1427. }
  1428. const double d2 = p * q;
  1429. h11 += dec_values[i] * dec_values[i] * d2;
  1430. h22 += d2;
  1431. h21 += dec_values[i] * d2;
  1432. const double d1 = t[i] - p;
  1433. g1 += dec_values[i] * d1;
  1434. g2 += d1;
  1435. }
  1436. // Stopping Criteria
  1437. if (fabs(g1) < eps && fabs(g2) < eps) { break; }
  1438. // Finding Newton direction: -inv(H') * g
  1439. const double det = h11 * h22 - h21 * h21;
  1440. const double dA = -(h22 * g1 - h21 * g2) / det;
  1441. const double dB = -(-h21 * g1 + h11 * g2) / det;
  1442. const double gd = g1 * dA + g2 * dB;
  1443. double stepsize = 1; // Line Search
  1444. while (stepsize >= min_step)
  1445. {
  1446. const double newA = A + stepsize * dA;
  1447. const double newB = B + stepsize * dB;
  1448. // New function value
  1449. double newf = 0.0;
  1450. for (i = 0; i < l; i++)
  1451. {
  1452. fApB = dec_values[i] * newA + newB;
  1453. if (fApB >= 0) { newf += t[i] * fApB + log(1 + exp(-fApB)); }
  1454. else { newf += (t[i] - 1) * fApB + log(1 + exp(fApB)); }
  1455. }
  1456. // Check sufficient decrease
  1457. if (newf < fval + 0.0001 * stepsize * gd)
  1458. {
  1459. A = newA;
  1460. B = newB;
  1461. fval = newf;
  1462. break;
  1463. }
  1464. stepsize = stepsize / 2.0;
  1465. }
  1466. if (stepsize < min_step)
  1467. {
  1468. info("Line search fails in two-class probability estimates\n");
  1469. break;
  1470. }
  1471. }
  1472. if (iter >= max_iter) { info("Reaching maximal iterations in two-class probability estimates\n"); }
  1473. free(t);
  1474. }
  1475. static double sigmoid_predict(const double decision_value, const double A, const double B)
  1476. {
  1477. const double fApB = decision_value * A + B;
  1478. // 1-p used later; avoid catastrophic cancellation
  1479. if (fApB >= 0) { return exp(-fApB) / (1.0 + exp(-fApB)); }
  1480. return 1.0 / (1 + exp(fApB));
  1481. }
  1482. // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
  1483. static void multiclass_probability(const int k, double** r, double* p)
  1484. {
  1485. int t, j;
  1486. int iter = 0, max_iter = max(100, k);
  1487. double** Q = Malloc(double*, k);
  1488. double* Qp = Malloc(double, k);
  1489. const double eps = 0.005 / k;
  1490. for (t = 0; t < k; t++)
  1491. {
  1492. p[t] = 1.0 / k; // Valid if k = 1
  1493. Q[t] = Malloc(double, k);
  1494. Q[t][t] = 0;
  1495. for (j = 0; j < t; j++)
  1496. {
  1497. Q[t][t] += r[j][t] * r[j][t];
  1498. Q[t][j] = Q[j][t];
  1499. }
  1500. for (j = t + 1; j < k; j++)
  1501. {
  1502. Q[t][t] += r[j][t] * r[j][t];
  1503. Q[t][j] = -r[j][t] * r[t][j];
  1504. }
  1505. }
  1506. for (; iter < max_iter; iter++)
  1507. {
  1508. // stopping condition, recalculate QP,pQP for numerical accuracy
  1509. double pQp = 0;
  1510. for (t = 0; t < k; t++)
  1511. {
  1512. Qp[t] = 0;
  1513. for (j = 0; j < k; j++) { Qp[t] += Q[t][j] * p[j]; }
  1514. pQp += p[t] * Qp[t];
  1515. }
  1516. double max_error = 0;
  1517. for (t = 0; t < k; t++)
  1518. {
  1519. const double error = fabs(Qp[t] - pQp);
  1520. if (error > max_error) { max_error = error; }
  1521. }
  1522. if (max_error < eps) { break; }
  1523. for (t = 0; t < k; t++)
  1524. {
  1525. const double diff = (-Qp[t] + pQp) / Q[t][t];
  1526. p[t] += diff;
  1527. pQp = (pQp + diff * (diff * Q[t][t] + 2 * Qp[t])) / (1 + diff) / (1 + diff);
  1528. for (j = 0; j < k; j++)
  1529. {
  1530. Qp[j] = (Qp[j] + diff * Q[t][j]) / (1 + diff);
  1531. p[j] /= (1 + diff);
  1532. }
  1533. }
  1534. }
  1535. if (iter >= max_iter) { info("Exceeds max_iter in multiclass_prob\n"); }
  1536. for (t = 0; t < k; t++) { free(Q[t]); }
  1537. free(Q);
  1538. free(Qp);
  1539. }
  1540. // Cross-validation decision values for probability estimates
  1541. static void svm_binary_svc_probability(const svm_problem* prob, const svm_parameter* param, const double Cp, const double Cn, double& probA, double& probB)
  1542. {
  1543. int i;
  1544. const int nr_fold = 5;
  1545. int* perm = Malloc(int, prob->l);
  1546. double* dec_values = Malloc(double, prob->l);
  1547. // random shuffle
  1548. for (i = 0; i < prob->l; i++) { perm[i] = i; }
  1549. for (i = 0; i < prob->l; i++)
  1550. {
  1551. const int j = i + rand() % (prob->l - i);
  1552. swap(perm[i], perm[j]);
  1553. }
  1554. for (i = 0; i < nr_fold; i++)
  1555. {
  1556. const int begin = i * prob->l / nr_fold;
  1557. const int end = (i + 1) * prob->l / nr_fold;
  1558. int j;
  1559. struct svm_problem subprob;
  1560. subprob.l = prob->l - (end - begin);
  1561. subprob.x = Malloc(struct svm_node*, subprob.l);
  1562. subprob.y = Malloc(double, subprob.l);
  1563. int k = 0;
  1564. for (j = 0; j < begin; j++)
  1565. {
  1566. subprob.x[k] = prob->x[perm[j]];
  1567. subprob.y[k] = prob->y[perm[j]];
  1568. ++k;
  1569. }
  1570. for (j = end; j < prob->l; j++)
  1571. {
  1572. subprob.x[k] = prob->x[perm[j]];
  1573. subprob.y[k] = prob->y[perm[j]];
  1574. ++k;
  1575. }
  1576. int p_count = 0, n_count = 0;
  1577. for (j = 0; j < k; j++)
  1578. {
  1579. if (subprob.y[j] > 0) { p_count++; }
  1580. else { n_count++; }
  1581. }
  1582. if (p_count == 0 && n_count == 0) { for (j = begin; j < end; j++) { dec_values[perm[j]] = 0; } }
  1583. else if (p_count > 0 && n_count == 0) { for (j = begin; j < end; j++) { dec_values[perm[j]] = 1; } }
  1584. else if (p_count == 0 && n_count > 0) { for (j = begin; j < end; j++) { dec_values[perm[j]] = -1; } }
  1585. else
  1586. {
  1587. svm_parameter subparam = *param;
  1588. subparam.probability = 0;
  1589. subparam.C = 1.0;
  1590. subparam.nr_weight = 2;
  1591. subparam.weight_label = Malloc(int, 2);
  1592. subparam.weight = Malloc(double, 2);
  1593. subparam.weight_label[0] = +1;
  1594. subparam.weight_label[1] = -1;
  1595. subparam.weight[0] = Cp;
  1596. subparam.weight[1] = Cn;
  1597. struct svm_model* submodel = svm_train(&subprob, &subparam);
  1598. for (j = begin; j < end; j++)
  1599. {
  1600. svm_predict_values(submodel, prob->x[perm[j]], &(dec_values[perm[j]]));
  1601. // ensure +1 -1 order; reason not using CV subroutine
  1602. dec_values[perm[j]] *= submodel->label[0];
  1603. }
  1604. svm_free_and_destroy_model(&submodel);
  1605. svm_destroy_param(&subparam);
  1606. }
  1607. free(subprob.x);
  1608. free(subprob.y);
  1609. }
  1610. sigmoid_train(prob->l, dec_values, prob->y, probA, probB);
  1611. free(dec_values);
  1612. free(perm);
  1613. }
  1614. // Return parameter of a Laplace distribution
  1615. static double svm_svr_probability(const svm_problem* prob, const svm_parameter* param)
  1616. {
  1617. int i;
  1618. const int nr_fold = 5;
  1619. double* ymv = Malloc(double, prob->l);
  1620. double mae = 0;
  1621. svm_parameter newparam = *param;
  1622. newparam.probability = 0;
  1623. svm_cross_validation(prob, &newparam, nr_fold, ymv);
  1624. for (i = 0; i < prob->l; i++)
  1625. {
  1626. ymv[i] = prob->y[i] - ymv[i];
  1627. mae += fabs(ymv[i]);
  1628. }
  1629. mae /= prob->l;
  1630. const double std = sqrt(2 * mae * mae);
  1631. int count = 0;
  1632. mae = 0;
  1633. for (i = 0; i < prob->l; i++)
  1634. {
  1635. if (fabs(ymv[i]) > 5 * std) { count = count + 1; }
  1636. else { mae += fabs(ymv[i]); }
  1637. }
  1638. mae /= (prob->l - count);
  1639. info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n", mae);
  1640. free(ymv);
  1641. return mae;
  1642. }
  1643. // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
  1644. // perm, length l, must be allocated before calling this subroutine
  1645. static void svm_group_classes(const svm_problem* prob, int* nr_class_ret, int** label_ret, int** start_ret, int** count_ret, int* perm)
  1646. {
  1647. const int l = prob->l;
  1648. int max_nr_class = 16;
  1649. int nr_class = 0;
  1650. int* label = Malloc(int, max_nr_class);
  1651. int* count = Malloc(int, max_nr_class);
  1652. int* data_label = Malloc(int, l);
  1653. int i;
  1654. for (i = 0; i < l; i++)
  1655. {
  1656. const int this_label = int(prob->y[i]);
  1657. int j;
  1658. for (j = 0; j < nr_class; j++)
  1659. {
  1660. if (this_label == label[j])
  1661. {
  1662. ++count[j];
  1663. break;
  1664. }
  1665. }
  1666. data_label[i] = j;
  1667. if (j == nr_class)
  1668. {
  1669. if (nr_class == max_nr_class)
  1670. {
  1671. max_nr_class *= 2;
  1672. label = (int*)realloc(label, max_nr_class * sizeof(int));
  1673. count = (int*)realloc(count, max_nr_class * sizeof(int));
  1674. }
  1675. label[nr_class] = this_label;
  1676. count[nr_class] = 1;
  1677. ++nr_class;
  1678. }
  1679. }
  1680. //
  1681. // Labels are ordered by their first occurrence in the training set.
  1682. // However, for two-class sets with -1/+1 labels and -1 appears first,
  1683. // we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances.
  1684. //
  1685. if (nr_class == 2 && label[0] == -1 && label[1] == 1)
  1686. {
  1687. swap(label[0], label[1]);
  1688. swap(count[0], count[1]);
  1689. for (i = 0; i < l; i++)
  1690. {
  1691. if (data_label[i] == 0) { data_label[i] = 1; }
  1692. else { data_label[i] = 0; }
  1693. }
  1694. }
  1695. int* start = Malloc(int, nr_class);
  1696. start[0] = 0;
  1697. for (i = 1; i < nr_class; i++) { start[i] = start[i - 1] + count[i - 1]; }
  1698. for (i = 0; i < l; i++)
  1699. {
  1700. perm[start[data_label[i]]] = i;
  1701. ++start[data_label[i]];
  1702. }
  1703. start[0] = 0;
  1704. for (i = 1; i < nr_class; i++) { start[i] = start[i - 1] + count[i - 1]; }
  1705. *nr_class_ret = nr_class;
  1706. *label_ret = label;
  1707. *start_ret = start;
  1708. *count_ret = count;
  1709. free(data_label);
  1710. }
  1711. //
  1712. // Interface functions
  1713. //
  1714. svm_model* svm_train(const svm_problem* prob, const svm_parameter* param)
  1715. {
  1716. svm_model* model = Malloc(svm_model, 1);
  1717. model->param = *param;
  1718. model->free_sv = 0; // XXX
  1719. if (param->svm_type == ONE_CLASS || param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR)
  1720. {
  1721. // regression or one-class-svm
  1722. model->nr_class = 2;
  1723. model->label = nullptr;
  1724. model->nSV = nullptr;
  1725. model->probA = nullptr;
  1726. model->probB = nullptr;
  1727. model->sv_coef = Malloc(double*, 1);
  1728. if (param->probability && (param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR))
  1729. {
  1730. model->probA = Malloc(double, 1);
  1731. model->probA[0] = svm_svr_probability(prob, param);
  1732. }
  1733. decision_function f = svm_train_one(prob, param, 0, 0);
  1734. model->rho = Malloc(double, 1);
  1735. model->rho[0] = f.rho;
  1736. int nSV = 0;
  1737. int i;
  1738. for (i = 0; i < prob->l; i++) { if (fabs(f.alpha[i]) > 0) { ++nSV; } }
  1739. model->l = nSV;
  1740. model->SV = Malloc(svm_node*, nSV);
  1741. model->sv_coef[0] = Malloc(double, nSV);
  1742. model->sv_indices = Malloc(int, nSV);
  1743. int j = 0;
  1744. for (i = 0; i < prob->l; i++)
  1745. {
  1746. if (fabs(f.alpha[i]) > 0)
  1747. {
  1748. model->SV[j] = prob->x[i];
  1749. model->sv_coef[0][j] = f.alpha[i];
  1750. model->sv_indices[j] = i + 1;
  1751. ++j;
  1752. }
  1753. }
  1754. free(f.alpha);
  1755. }
  1756. else
  1757. {
  1758. // classification
  1759. int l = prob->l;
  1760. int nr_class;
  1761. int* label = nullptr;
  1762. int* start = nullptr;
  1763. int* count = nullptr;
  1764. int* perm = Malloc(int, l);
  1765. // group training data of the same class
  1766. svm_group_classes(prob, &nr_class, &label, &start, &count, perm);
  1767. if (nr_class == 1) { info("WARNING: training data in only one class. See README for details.\n"); }
  1768. svm_node** x = Malloc(svm_node*, l);
  1769. int i;
  1770. for (i = 0; i < l; i++) { x[i] = prob->x[perm[i]]; }
  1771. // calculate weighted C
  1772. double* weighted_C = Malloc(double, nr_class);
  1773. for (i = 0; i < nr_class; i++) { weighted_C[i] = param->C; }
  1774. for (i = 0; i < param->nr_weight; i++)
  1775. {
  1776. int j;
  1777. for (j = 0; j < nr_class; j++) { if (param->weight_label[i] == label[j]) { break; } }
  1778. if (j == nr_class) { fprintf(stderr, "WARNING: class label %d specified in weight is not found\n", param->weight_label[i]); }
  1779. else { weighted_C[j] *= param->weight[i]; }
  1780. }
  1781. // train k*(k-1)/2 models
  1782. bool* nonzero = Malloc(bool, l);
  1783. for (i = 0; i < l; i++) { nonzero[i] = false; }
  1784. decision_function* f = Malloc(decision_function, nr_class * (nr_class - 1) / 2);
  1785. double *probA = nullptr, *probB = nullptr;
  1786. if (param->probability)
  1787. {
  1788. probA = Malloc(double, nr_class * (nr_class - 1) / 2);
  1789. probB = Malloc(double, nr_class * (nr_class - 1) / 2);
  1790. }
  1791. int p = 0;
  1792. for (i = 0; i < nr_class; i++)
  1793. {
  1794. for (int j = i + 1; j < nr_class; j++)
  1795. {
  1796. svm_problem sub_prob;
  1797. int si = start[i], sj = start[j];
  1798. int ci = count[i], cj = count[j];
  1799. sub_prob.l = ci + cj;
  1800. sub_prob.x = Malloc(svm_node*, sub_prob.l);
  1801. sub_prob.y = Malloc(double, sub_prob.l);
  1802. int k;
  1803. for (k = 0; k < ci; k++)
  1804. {
  1805. sub_prob.x[k] = x[si + k];
  1806. sub_prob.y[k] = +1;
  1807. }
  1808. for (k = 0; k < cj; k++)
  1809. {
  1810. sub_prob.x[ci + k] = x[sj + k];
  1811. sub_prob.y[ci + k] = -1;
  1812. }
  1813. if (param->probability) { svm_binary_svc_probability(&sub_prob, param, weighted_C[i], weighted_C[j], probA[p], probB[p]); }
  1814. f[p] = svm_train_one(&sub_prob, param, weighted_C[i], weighted_C[j]);
  1815. for (k = 0; k < ci; k++) { if (!nonzero[si + k] && fabs(f[p].alpha[k]) > 0) { nonzero[si + k] = true; } }
  1816. for (k = 0; k < cj; k++) { if (!nonzero[sj + k] && fabs(f[p].alpha[ci + k]) > 0) { nonzero[sj + k] = true; } }
  1817. free(sub_prob.x);
  1818. free(sub_prob.y);
  1819. ++p;
  1820. }
  1821. }
  1822. // build output
  1823. model->nr_class = nr_class;
  1824. model->label = Malloc(int, nr_class);
  1825. for (i = 0; i < nr_class; i++) { model->label[i] = label[i]; }
  1826. model->rho = Malloc(double, nr_class * (nr_class - 1) / 2);
  1827. for (i = 0; i < nr_class * (nr_class - 1) / 2; i++) { model->rho[i] = f[i].rho; }
  1828. if (param->probability)
  1829. {
  1830. model->probA = Malloc(double, nr_class * (nr_class - 1) / 2);
  1831. model->probB = Malloc(double, nr_class * (nr_class - 1) / 2);
  1832. for (i = 0; i < nr_class * (nr_class - 1) / 2; i++)
  1833. {
  1834. model->probA[i] = probA[i];
  1835. model->probB[i] = probB[i];
  1836. }
  1837. }
  1838. else
  1839. {
  1840. model->probA = nullptr;
  1841. model->probB = nullptr;
  1842. }
  1843. int total_sv = 0;
  1844. int* nz_count = Malloc(int, nr_class);
  1845. model->nSV = Malloc(int, nr_class);
  1846. for (i = 0; i < nr_class; i++)
  1847. {
  1848. int nSV = 0;
  1849. for (int j = 0; j < count[i]; j++)
  1850. {
  1851. if (nonzero[start[i] + j])
  1852. {
  1853. ++nSV;
  1854. ++total_sv;
  1855. }
  1856. }
  1857. model->nSV[i] = nSV;
  1858. nz_count[i] = nSV;
  1859. }
  1860. info("Total nSV = %d\n", total_sv);
  1861. model->l = total_sv;
  1862. model->SV = Malloc(svm_node*, total_sv);
  1863. model->sv_indices = Malloc(int, total_sv);
  1864. p = 0;
  1865. for (i = 0; i < l; i++)
  1866. {
  1867. if (nonzero[i])
  1868. {
  1869. model->SV[p] = x[i];
  1870. model->sv_indices[p++] = perm[i] + 1;
  1871. }
  1872. }
  1873. int* nz_start = Malloc(int, nr_class);
  1874. nz_start[0] = 0;
  1875. for (i = 1; i < nr_class; i++) { nz_start[i] = nz_start[i - 1] + nz_count[i - 1]; }
  1876. model->sv_coef = Malloc(double*, nr_class - 1);
  1877. for (i = 0; i < nr_class - 1; i++) { model->sv_coef[i] = Malloc(double, total_sv); }
  1878. p = 0;
  1879. for (i = 0; i < nr_class; i++)
  1880. {
  1881. for (int j = i + 1; j < nr_class; j++)
  1882. {
  1883. // classifier (i,j): coefficients with
  1884. // i are in sv_coef[j-1][nz_start[i]...],
  1885. // j are in sv_coef[i][nz_start[j]...]
  1886. int si = start[i];
  1887. int sj = start[j];
  1888. int ci = count[i];
  1889. int cj = count[j];
  1890. int q = nz_start[i];
  1891. int k;
  1892. for (k = 0; k < ci; k++) { if (nonzero[si + k]) { model->sv_coef[j - 1][q++] = f[p].alpha[k]; } }
  1893. q = nz_start[j];
  1894. for (k = 0; k < cj; k++) { if (nonzero[sj + k]) { model->sv_coef[i][q++] = f[p].alpha[ci + k]; } }
  1895. ++p;
  1896. }
  1897. }
  1898. free(label);
  1899. free(probA);
  1900. free(probB);
  1901. free(count);
  1902. free(perm);
  1903. free(start);
  1904. free(x);
  1905. free(weighted_C);
  1906. free(nonzero);
  1907. for (i = 0; i < nr_class * (nr_class - 1) / 2; i++) { free(f[i].alpha); }
  1908. free(f);
  1909. free(nz_count);
  1910. free(nz_start);
  1911. }
  1912. return model;
  1913. }
  1914. // Stratified cross validation
  1915. void svm_cross_validation(const svm_problem* prob, const svm_parameter* param, int nr_fold, double* target)
  1916. {
  1917. int i;
  1918. const int l = prob->l;
  1919. int* perm = Malloc(int, l);
  1920. int nr_class;
  1921. if (nr_fold > l)
  1922. {
  1923. nr_fold = l;
  1924. fprintf(stderr, "WARNING: # folds > # data. Will use # folds = # data instead (i.e., leave-one-out cross validation)\n");
  1925. }
  1926. int* fold_start = Malloc(int, nr_fold + 1);
  1927. // stratified cv may not give leave-one-out rate
  1928. // Each class to l folds -> some folds may have zero elements
  1929. if ((param->svm_type == C_SVC || param->svm_type == NU_SVC) && nr_fold < l)
  1930. {
  1931. int* start = nullptr;
  1932. int* label = nullptr;
  1933. int* count = nullptr;
  1934. svm_group_classes(prob, &nr_class, &label, &start, &count, perm);
  1935. // random shuffle and then data grouped by fold using the array perm
  1936. int* fold_count = Malloc(int, nr_fold);
  1937. int c;
  1938. int* index = Malloc(int, l);
  1939. for (i = 0; i < l; i++) { index[i] = perm[i]; }
  1940. for (c = 0; c < nr_class; c++)
  1941. {
  1942. for (i = 0; i < count[c]; i++)
  1943. {
  1944. const int j = i + rand() % (count[c] - i);
  1945. swap(index[start[c] + j], index[start[c] + i]);
  1946. }
  1947. }
  1948. for (i = 0; i < nr_fold; i++)
  1949. {
  1950. fold_count[i] = 0;
  1951. for (c = 0; c < nr_class; c++) { fold_count[i] += (i + 1) * count[c] / nr_fold - i * count[c] / nr_fold; }
  1952. }
  1953. fold_start[0] = 0;
  1954. for (i = 1; i <= nr_fold; i++) { fold_start[i] = fold_start[i - 1] + fold_count[i - 1]; }
  1955. for (c = 0; c < nr_class; c++)
  1956. {
  1957. for (i = 0; i < nr_fold; i++)
  1958. {
  1959. const int begin = start[c] + i * count[c] / nr_fold;
  1960. const int end = start[c] + (i + 1) * count[c] / nr_fold;
  1961. for (int j = begin; j < end; j++)
  1962. {
  1963. perm[fold_start[i]] = index[j];
  1964. fold_start[i]++;
  1965. }
  1966. }
  1967. }
  1968. fold_start[0] = 0;
  1969. for (i = 1; i <= nr_fold; i++) { fold_start[i] = fold_start[i - 1] + fold_count[i - 1]; }
  1970. free(start);
  1971. free(label);
  1972. free(count);
  1973. free(index);
  1974. free(fold_count);
  1975. }
  1976. else
  1977. {
  1978. for (i = 0; i < l; i++) { perm[i] = i; }
  1979. for (i = 0; i < l; i++)
  1980. {
  1981. const int j = i + rand() % (l - i);
  1982. swap(perm[i], perm[j]);
  1983. }
  1984. for (i = 0; i <= nr_fold; i++) { fold_start[i] = i * l / nr_fold; }
  1985. }
  1986. for (i = 0; i < nr_fold; i++)
  1987. {
  1988. const int begin = fold_start[i];
  1989. const int end = fold_start[i + 1];
  1990. int j;
  1991. struct svm_problem subprob;
  1992. subprob.l = l - (end - begin);
  1993. subprob.x = Malloc(struct svm_node*, subprob.l);
  1994. subprob.y = Malloc(double, subprob.l);
  1995. int k = 0;
  1996. for (j = 0; j < begin; j++)
  1997. {
  1998. subprob.x[k] = prob->x[perm[j]];
  1999. subprob.y[k] = prob->y[perm[j]];
  2000. ++k;
  2001. }
  2002. for (j = end; j < l; j++)
  2003. {
  2004. subprob.x[k] = prob->x[perm[j]];
  2005. subprob.y[k] = prob->y[perm[j]];
  2006. ++k;
  2007. }
  2008. struct svm_model* submodel = svm_train(&subprob, param);
  2009. if (param->probability && (param->svm_type == C_SVC || param->svm_type == NU_SVC))
  2010. {
  2011. double* prob_estimates = Malloc(double, svm_get_nr_class(submodel));
  2012. for (j = begin; j < end; j++) { target[perm[j]] = svm_predict_probability(submodel, prob->x[perm[j]], prob_estimates); }
  2013. free(prob_estimates);
  2014. }
  2015. else { for (j = begin; j < end; j++) { target[perm[j]] = svm_predict(submodel, prob->x[perm[j]]); } }
  2016. svm_free_and_destroy_model(&submodel);
  2017. free(subprob.x);
  2018. free(subprob.y);
  2019. }
  2020. free(fold_start);
  2021. free(perm);
  2022. }
  2023. int svm_get_svm_type(const svm_model* model) { return model->param.svm_type; }
  2024. int svm_get_nr_class(const svm_model* model) { return model->nr_class; }
  2025. void svm_get_labels(const svm_model* model, int* label)
  2026. {
  2027. if (model->label != nullptr) { for (int i = 0; i < model->nr_class; i++) { label[i] = model->label[i]; } }
  2028. }
  2029. void svm_get_sv_indices(const svm_model* model, int* indices)
  2030. {
  2031. if (model->sv_indices != nullptr) { for (int i = 0; i < model->l; i++) { indices[i] = model->sv_indices[i]; } }
  2032. }
  2033. int svm_get_nr_sv(const svm_model* model) { return model->l; }
  2034. double svm_get_svr_probability(const svm_model* model)
  2035. {
  2036. if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && model->probA != nullptr) { return model->probA[0]; }
  2037. fprintf(stderr, "Model doesn't contain information for SVR probability inference\n");
  2038. return 0;
  2039. }
  2040. double svm_predict_values(const svm_model* model, const svm_node* x, double* dec_values)
  2041. {
  2042. int i;
  2043. if (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR)
  2044. {
  2045. double* sv_coef = model->sv_coef[0];
  2046. double sum = 0;
  2047. for (i = 0; i < model->l; i++) { sum += sv_coef[i] * Kernel::k_function(x, model->SV[i], model->param); }
  2048. sum -= model->rho[0];
  2049. *dec_values = sum;
  2050. if (model->param.svm_type == ONE_CLASS) { return (sum > 0) ? 1 : -1; }
  2051. return sum;
  2052. }
  2053. const int nr_class = model->nr_class;
  2054. const int l = model->l;
  2055. double* kvalue = Malloc(double, l);
  2056. for (i = 0; i < l; i++) { kvalue[i] = Kernel::k_function(x, model->SV[i], model->param); }
  2057. int* start = Malloc(int, nr_class);
  2058. start[0] = 0;
  2059. for (i = 1; i < nr_class; i++) { start[i] = start[i - 1] + model->nSV[i - 1]; }
  2060. int* vote = Malloc(int, nr_class);
  2061. for (i = 0; i < nr_class; i++) { vote[i] = 0; }
  2062. int p = 0;
  2063. for (i = 0; i < nr_class; i++)
  2064. {
  2065. for (int j = i + 1; j < nr_class; j++)
  2066. {
  2067. double sum = 0;
  2068. const int si = start[i];
  2069. const int sj = start[j];
  2070. const int ci = model->nSV[i];
  2071. const int cj = model->nSV[j];
  2072. int k;
  2073. double* coef1 = model->sv_coef[j - 1];
  2074. double* coef2 = model->sv_coef[i];
  2075. for (k = 0; k < ci; k++) { sum += coef1[si + k] * kvalue[si + k]; }
  2076. for (k = 0; k < cj; k++) { sum += coef2[sj + k] * kvalue[sj + k]; }
  2077. sum -= model->rho[p];
  2078. dec_values[p] = sum;
  2079. if (dec_values[p] > 0) { ++vote[i]; }
  2080. else { ++vote[j]; }
  2081. p++;
  2082. }
  2083. }
  2084. int vote_max_idx = 0;
  2085. for (i = 1; i < nr_class; i++) { if (vote[i] > vote[vote_max_idx]) { vote_max_idx = i; } }
  2086. free(kvalue);
  2087. free(start);
  2088. free(vote);
  2089. return model->label[vote_max_idx];
  2090. }
  2091. double svm_predict(const svm_model* model, const svm_node* x)
  2092. {
  2093. const int nr_class = model->nr_class;
  2094. double* dec_values;
  2095. if (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) { dec_values = Malloc(double, 1); }
  2096. else { dec_values = Malloc(double, nr_class * (nr_class - 1) / 2); }
  2097. const double pred_result = svm_predict_values(model, x, dec_values);
  2098. free(dec_values);
  2099. return pred_result;
  2100. }
  2101. double svm_predict_probability(const svm_model* model, const svm_node* x, double* prob_estimates)
  2102. {
  2103. if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && model->probA != nullptr && model->probB != nullptr)
  2104. {
  2105. int i;
  2106. const int nr_class = model->nr_class;
  2107. double* dec_values = Malloc(double, nr_class * (nr_class - 1) / 2);
  2108. svm_predict_values(model, x, dec_values);
  2109. const double min_prob = 1e-7;
  2110. double** pairwise_prob = Malloc(double*, nr_class);
  2111. for (i = 0; i < nr_class; i++) { pairwise_prob[i] = Malloc(double, nr_class); }
  2112. int k = 0;
  2113. for (i = 0; i < nr_class; i++)
  2114. {
  2115. for (int j = i + 1; j < nr_class; j++)
  2116. {
  2117. pairwise_prob[i][j] = min(max(sigmoid_predict(dec_values[k], model->probA[k], model->probB[k]), min_prob), 1 - min_prob);
  2118. pairwise_prob[j][i] = 1 - pairwise_prob[i][j];
  2119. k++;
  2120. }
  2121. }
  2122. if (nr_class == 2)
  2123. {
  2124. prob_estimates[0] = pairwise_prob[0][1];
  2125. prob_estimates[1] = pairwise_prob[1][0];
  2126. }
  2127. else { multiclass_probability(nr_class, pairwise_prob, prob_estimates); }
  2128. int prob_max_idx = 0;
  2129. for (i = 1; i < nr_class; i++) { if (prob_estimates[i] > prob_estimates[prob_max_idx]) { prob_max_idx = i; } }
  2130. for (i = 0; i < nr_class; i++) { free(pairwise_prob[i]); }
  2131. free(dec_values);
  2132. free(pairwise_prob);
  2133. return model->label[prob_max_idx];
  2134. }
  2135. return svm_predict(model, x);
  2136. }
  // Keyword spellings used in model files. A name's position in its table is
  // the numeric svm_type / kernel_type code; each table ends with a null entry.
  static const char* svm_type_table[] =
  {
      "c_svc",
      "nu_svc",
      "one_class",
      "epsilon_svr",
      "nu_svr",
      nullptr
  };
  static const char* kernel_type_table[] =
  {
      "linear",
      "polynomial",
      "rbf",
      "sigmoid",
      "precomputed",
      nullptr
  };
  2139. int svm_save_model(const char* model_file_name, const svm_model* model)
  2140. {
  2141. FILE* fp = fopen(model_file_name, "w");
  2142. if (fp == nullptr) { return -1; }
  2143. char* old_locale = setlocale(LC_ALL, nullptr);
  2144. if (old_locale) { old_locale = strdup(old_locale); }
  2145. setlocale(LC_ALL, "C");
  2146. const svm_parameter& param = model->param;
  2147. fprintf(fp, "svm_type %s\n", svm_type_table[param.svm_type]);
  2148. fprintf(fp, "kernel_type %s\n", kernel_type_table[param.kernel_type]);
  2149. if (param.kernel_type == POLY) { fprintf(fp, "degree %d\n", param.degree); }
  2150. if (param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID) { fprintf(fp, "gamma %.17g\n", param.gamma); }
  2151. if (param.kernel_type == POLY || param.kernel_type == SIGMOID) { fprintf(fp, "coef0 %.17g\n", param.coef0); }
  2152. const int nr_class = model->nr_class;
  2153. const int l = model->l;
  2154. fprintf(fp, "nr_class %d\n", nr_class);
  2155. fprintf(fp, "total_sv %d\n", l);
  2156. {
  2157. fprintf(fp, "rho");
  2158. for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++) { fprintf(fp, " %.17g", model->rho[i]); }
  2159. fprintf(fp, "\n");
  2160. }
  2161. if (model->label)
  2162. {
  2163. fprintf(fp, "label");
  2164. for (int i = 0; i < nr_class; i++) { fprintf(fp, " %d", model->label[i]); }
  2165. fprintf(fp, "\n");
  2166. }
  2167. if (model->probA) // regression has probA only
  2168. {
  2169. fprintf(fp, "probA");
  2170. for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++) { fprintf(fp, " %.17g", model->probA[i]); }
  2171. fprintf(fp, "\n");
  2172. }
  2173. if (model->probB)
  2174. {
  2175. fprintf(fp, "probB");
  2176. for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++) { fprintf(fp, " %.17g", model->probB[i]); }
  2177. fprintf(fp, "\n");
  2178. }
  2179. if (model->nSV)
  2180. {
  2181. fprintf(fp, "nr_sv");
  2182. for (int i = 0; i < nr_class; i++) { fprintf(fp, " %d", model->nSV[i]); }
  2183. fprintf(fp, "\n");
  2184. }
  2185. fprintf(fp, "SV\n");
  2186. const double* const* sv_coef = model->sv_coef;
  2187. const svm_node* const* SV = model->SV;
  2188. for (int i = 0; i < l; i++)
  2189. {
  2190. for (int j = 0; j < nr_class - 1; j++) { fprintf(fp, "%.17g ", sv_coef[j][i]); }
  2191. const svm_node* p = SV[i];
  2192. if (param.kernel_type == PRECOMPUTED) { fprintf(fp, "0:%d ", int(p->value)); }
  2193. else
  2194. {
  2195. while (p->index != -1)
  2196. {
  2197. fprintf(fp, "%d:%.8g ", p->index, p->value);
  2198. p++;
  2199. }
  2200. }
  2201. fprintf(fp, "\n");
  2202. }
  2203. setlocale(LC_ALL, old_locale);
  2204. free(old_locale);
  2205. if (ferror(fp) != 0 || fclose(fp) != 0) { return -1; }
  2206. return 0;
  2207. }
  // Shared line buffer used by readline(); `max_line_len` is its capacity in
  // bytes. The caller allocates `line` before the first readline() call and
  // frees it afterwards.
  static char* line = nullptr;
  static int max_line_len;

  // Reads one complete line (including its '\n', if any) from `input` into
  // `line`, doubling the buffer until the line fits.
  //
  // Returns `line`, or nullptr when nothing could be read (end of file or
  // read error) or when the buffer could not be grown. In the latter case
  // `line` still points to the old, valid buffer so the caller can free it.
  static char* readline(FILE* input)
  {
      if (fgets(line, max_line_len, input) == nullptr) { return nullptr; }

      while (strrchr(line, '\n') == nullptr)
      {
          // Grow through a temporary so a failed realloc neither loses the
          // existing buffer nor leaves `line` null for the strlen below.
          const int grown_len = max_line_len * 2;
          char* grown = (char*)realloc(line, grown_len);
          if (grown == nullptr) { return nullptr; }
          line = grown;
          max_line_len = grown_len;

          const int len = int(strlen(line));
          if (fgets(line + len, max_line_len - len, input) == nullptr) { break; }
      }
      return line;
  }
  //
  // FSCANF reads one item with fscanf and makes the enclosing function
  // return false when the conversion does not succeed.
  // The do { ... } while (0) wrapper turns the macro into a single statement,
  // so it stays unambiguous when used as the unbraced body of an if:
  //     if (...)
  //         FSCANF(...);
  //
  #define FSCANF(_stream, _format, _var) \
      do \
      { \
          if (fscanf(_stream, _format, _var) != 1) { return false; } \
      } while (0)
  2230. bool read_model_header(FILE* fp, svm_model* model)
  2231. {
  2232. svm_parameter& param = model->param;
  2233. // parameters for training only won't be assigned, but arrays are assigned as NULL for safety
  2234. param.nr_weight = 0;
  2235. param.weight_label = nullptr;
  2236. param.weight = nullptr;
  2237. char cmd[81];
  2238. while (true)
  2239. {
  2240. FSCANF(fp, "%80s", cmd);
  2241. if (strcmp(cmd, "svm_type") == 0)
  2242. {
  2243. FSCANF(fp, "%80s", cmd);
  2244. int i;
  2245. for (i = 0; svm_type_table[i]; i++)
  2246. {
  2247. if (strcmp(svm_type_table[i], cmd) == 0)
  2248. {
  2249. param.svm_type = i;
  2250. break;
  2251. }
  2252. }
  2253. if (svm_type_table[i] == nullptr)
  2254. {
  2255. fprintf(stderr, "unknown svm type.\n");
  2256. return false;
  2257. }
  2258. }
  2259. else if (strcmp(cmd, "kernel_type") == 0)
  2260. {
  2261. FSCANF(fp, "%80s", cmd);
  2262. int i;
  2263. for (i = 0; kernel_type_table[i]; i++)
  2264. {
  2265. if (strcmp(kernel_type_table[i], cmd) == 0)
  2266. {
  2267. param.kernel_type = i;
  2268. break;
  2269. }
  2270. }
  2271. if (kernel_type_table[i] == nullptr)
  2272. {
  2273. fprintf(stderr, "unknown kernel function.\n");
  2274. return false;
  2275. }
  2276. }
  2277. else if (strcmp(cmd, "degree") == 0) { FSCANF(fp, "%d", &param.degree); }
  2278. else if (strcmp(cmd, "gamma") == 0) { FSCANF(fp, "%lf", &param.gamma); }
  2279. else if (strcmp(cmd, "coef0") == 0) { FSCANF(fp, "%lf", &param.coef0); }
  2280. else if (strcmp(cmd, "nr_class") == 0) { FSCANF(fp, "%d", &model->nr_class); }
  2281. else if (strcmp(cmd, "total_sv") == 0) { FSCANF(fp, "%d", &model->l); }
  2282. else if (strcmp(cmd, "rho") == 0)
  2283. {
  2284. const int n = model->nr_class * (model->nr_class - 1) / 2;
  2285. model->rho = Malloc(double, n);
  2286. for (int i = 0; i < n; i++) { FSCANF(fp, "%lf", &model->rho[i]); }
  2287. }
  2288. else if (strcmp(cmd, "label") == 0)
  2289. {
  2290. const int n = model->nr_class;
  2291. model->label = Malloc(int, n);
  2292. for (int i = 0; i < n; i++) { FSCANF(fp, "%d", &model->label[i]); }
  2293. }
  2294. else if (strcmp(cmd, "probA") == 0)
  2295. {
  2296. const int n = model->nr_class * (model->nr_class - 1) / 2;
  2297. model->probA = Malloc(double, n);
  2298. for (int i = 0; i < n; i++) { FSCANF(fp, "%lf", &model->probA[i]); }
  2299. }
  2300. else if (strcmp(cmd, "probB") == 0)
  2301. {
  2302. const int n = model->nr_class * (model->nr_class - 1) / 2;
  2303. model->probB = Malloc(double, n);
  2304. for (int i = 0; i < n; i++) { FSCANF(fp, "%lf", &model->probB[i]); }
  2305. }
  2306. else if (strcmp(cmd, "nr_sv") == 0)
  2307. {
  2308. const int n = model->nr_class;
  2309. model->nSV = Malloc(int, n);
  2310. for (int i = 0; i < n; i++) { FSCANF(fp, "%d", &model->nSV[i]); }
  2311. }
  2312. else if (strcmp(cmd, "SV") == 0)
  2313. {
  2314. while (true)
  2315. {
  2316. const int c = getc(fp);
  2317. if (c == EOF || c == '\n') { break; }
  2318. }
  2319. break;
  2320. }
  2321. else
  2322. {
  2323. fprintf(stderr, "unknown text in model file: [%s]\n", cmd);
  2324. return false;
  2325. }
  2326. }
  2327. return true;
  2328. }
  2329. svm_model* svm_load_model(const char* model_file_name)
  2330. {
  2331. FILE* fp = fopen(model_file_name, "rb");
  2332. if (fp == nullptr) { return nullptr; }
  2333. char* old_locale = setlocale(LC_ALL, nullptr);
  2334. if (old_locale) { old_locale = strdup(old_locale); }
  2335. setlocale(LC_ALL, "C");
  2336. // read parameters
  2337. svm_model* model = Malloc(svm_model, 1);
  2338. model->rho = nullptr;
  2339. model->probA = nullptr;
  2340. model->probB = nullptr;
  2341. model->sv_indices = nullptr;
  2342. model->label = nullptr;
  2343. model->nSV = nullptr;
  2344. // read header
  2345. if (!read_model_header(fp, model))
  2346. {
  2347. fprintf(stderr, "ERROR: fscanf failed to read model\n");
  2348. setlocale(LC_ALL, old_locale);
  2349. free(old_locale);
  2350. free(model->rho);
  2351. free(model->label);
  2352. free(model->nSV);
  2353. free(model);
  2354. return nullptr;
  2355. }
  2356. // read sv_coef and SV
  2357. int elements = 0;
  2358. const long pos = ftell(fp);
  2359. max_line_len = 1024;
  2360. line = Malloc(char, max_line_len);
  2361. char *p, *endptr;
  2362. while (readline(fp) != nullptr)
  2363. {
  2364. p = strtok(line, ":");
  2365. while (true)
  2366. {
  2367. p = strtok(nullptr, ":");
  2368. if (p == nullptr) { break; }
  2369. ++elements;
  2370. }
  2371. }
  2372. elements += model->l;
  2373. fseek(fp, pos, SEEK_SET);
  2374. const int m = model->nr_class - 1;
  2375. const int l = model->l;
  2376. model->sv_coef = Malloc(double*, m);
  2377. int i;
  2378. for (i = 0; i < m; i++) { model->sv_coef[i] = Malloc(double, l); }
  2379. model->SV = Malloc(svm_node*, l);
  2380. svm_node* x_space = nullptr;
  2381. if (l > 0) { x_space = Malloc(svm_node, elements); }
  2382. int j = 0;
  2383. for (i = 0; i < l; i++)
  2384. {
  2385. readline(fp);
  2386. model->SV[i] = &x_space[j];
  2387. p = strtok(line, " \t");
  2388. model->sv_coef[0][i] = strtod(p, &endptr);
  2389. for (int k = 1; k < m; k++)
  2390. {
  2391. p = strtok(nullptr, " \t");
  2392. model->sv_coef[k][i] = strtod(p, &endptr);
  2393. }
  2394. while (true)
  2395. {
  2396. char* idx = strtok(nullptr, ":");
  2397. char* val = strtok(nullptr, " \t");
  2398. if (val == nullptr) { break; }
  2399. x_space[j].index = int(strtol(idx, &endptr, 10));
  2400. x_space[j].value = strtod(val, &endptr);
  2401. ++j;
  2402. }
  2403. x_space[j++].index = -1;
  2404. }
  2405. free(line);
  2406. setlocale(LC_ALL, old_locale);
  2407. free(old_locale);
  2408. if (ferror(fp) != 0 || fclose(fp) != 0) { return nullptr; }
  2409. model->free_sv = 1; // XXX
  2410. return model;
  2411. }
  2412. void svm_free_model_content(svm_model* model_ptr)
  2413. {
  2414. if (model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != nullptr) { free((void*)(model_ptr->SV[0])); }
  2415. if (model_ptr->sv_coef) { for (int i = 0; i < model_ptr->nr_class - 1; i++) { free(model_ptr->sv_coef[i]); } }
  2416. free(model_ptr->SV);
  2417. model_ptr->SV = nullptr;
  2418. free(model_ptr->sv_coef);
  2419. model_ptr->sv_coef = nullptr;
  2420. free(model_ptr->rho);
  2421. model_ptr->rho = nullptr;
  2422. free(model_ptr->label);
  2423. model_ptr->label = nullptr;
  2424. free(model_ptr->probA);
  2425. model_ptr->probA = nullptr;
  2426. free(model_ptr->probB);
  2427. model_ptr->probB = nullptr;
  2428. free(model_ptr->sv_indices);
  2429. model_ptr->sv_indices = nullptr;
  2430. free(model_ptr->nSV);
  2431. model_ptr->nSV = nullptr;
  2432. }
  2433. void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
  2434. {
  2435. if (model_ptr_ptr != nullptr && *model_ptr_ptr != nullptr)
  2436. {
  2437. svm_free_model_content(*model_ptr_ptr);
  2438. free(*model_ptr_ptr);
  2439. *model_ptr_ptr = nullptr;
  2440. }
  2441. }
  2442. void svm_destroy_param(svm_parameter* param)
  2443. {
  2444. free(param->weight_label);
  2445. free(param->weight);
  2446. }
  2447. const char* svm_check_parameter(const svm_problem* prob, const svm_parameter* param)
  2448. {
  2449. // svm_type
  2450. const int svm_type = param->svm_type;
  2451. if (svm_type != C_SVC && svm_type != NU_SVC && svm_type != ONE_CLASS && svm_type != EPSILON_SVR && svm_type != NU_SVR) { return "unknown svm type"; }
  2452. // kernel_type, degree
  2453. const int kernel_type = param->kernel_type;
  2454. if (kernel_type != LINEAR && kernel_type != POLY && kernel_type != RBF && kernel_type != SIGMOID && kernel_type != PRECOMPUTED)
  2455. {
  2456. return "unknown kernel type";
  2457. }
  2458. if ((kernel_type == POLY || kernel_type == RBF || kernel_type == SIGMOID) && param->gamma < 0) { return "gamma < 0"; }
  2459. if (kernel_type == POLY && param->degree < 0) { return "degree of polynomial kernel < 0"; }
  2460. // cache_size,eps,C,nu,p,shrinking
  2461. if (param->cache_size <= 0) { return "cache_size <= 0"; }
  2462. if (param->eps <= 0) { return "eps <= 0"; }
  2463. if (svm_type == C_SVC || svm_type == EPSILON_SVR || svm_type == NU_SVR) { if (param->C <= 0) { return "C <= 0"; } }
  2464. if (svm_type == NU_SVC || svm_type == ONE_CLASS || svm_type == NU_SVR) { if (param->nu <= 0 || param->nu > 1) { return "nu <= 0 or nu > 1"; } }
  2465. if (svm_type == EPSILON_SVR) { if (param->p < 0) { return "p < 0"; } }
  2466. if (param->shrinking != 0 && param->shrinking != 1) { return "shrinking != 0 and shrinking != 1"; }
  2467. if (param->probability != 0 && param->probability != 1) { return "probability != 0 and probability != 1"; }
  2468. if (param->probability == 1 && svm_type == ONE_CLASS) { return "one-class SVM probability output not supported yet"; }
  2469. // check whether nu-svc is feasible
  2470. if (svm_type == NU_SVC)
  2471. {
  2472. const int l = prob->l;
  2473. int max_nr_class = 16;
  2474. int nr_class = 0;
  2475. int* label = Malloc(int, max_nr_class);
  2476. int* count = Malloc(int, max_nr_class);
  2477. int i;
  2478. for (i = 0; i < l; i++)
  2479. {
  2480. const int this_label = int(prob->y[i]);
  2481. int j;
  2482. for (j = 0; j < nr_class; j++)
  2483. {
  2484. if (this_label == label[j])
  2485. {
  2486. ++count[j];
  2487. break;
  2488. }
  2489. }
  2490. if (j == nr_class)
  2491. {
  2492. if (nr_class == max_nr_class)
  2493. {
  2494. max_nr_class *= 2;
  2495. label = (int*)realloc(label, max_nr_class * sizeof(int));
  2496. count = (int*)realloc(count, max_nr_class * sizeof(int));
  2497. }
  2498. label[nr_class] = this_label;
  2499. count[nr_class] = 1;
  2500. ++nr_class;
  2501. }
  2502. }
  2503. for (i = 0; i < nr_class; i++)
  2504. {
  2505. const int n1 = count[i];
  2506. for (int j = i + 1; j < nr_class; j++)
  2507. {
  2508. const int n2 = count[j];
  2509. if (param->nu * (n1 + n2) / 2 > min(n1, n2))
  2510. {
  2511. free(label);
  2512. free(count);
  2513. return "specified nu is infeasible";
  2514. }
  2515. }
  2516. }
  2517. free(label);
  2518. free(count);
  2519. }
  2520. return nullptr;
  2521. }
  2522. int svm_check_probability_model(const svm_model* model)
  2523. {
  2524. return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && model->probA != nullptr && model->probB != nullptr)
  2525. || ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && model->probA != nullptr);
  2526. }
  2527. void svm_set_print_string_function(void (*print_func)(const char*))
  2528. {
  2529. if (print_func == nullptr) { svm_print_string = &print_string_stdout; }
  2530. else { svm_print_string = print_func; }
  2531. }