Line data Source code
1 : /**
2 : * @file cppdescent.cpp
3 : * @author Konstantinos Chousos
4 : * @brief Implementation of the cppdescent library.
5 : * @version 0.1
6 : * @date 2023-10-27
7 : *
8 : * @copyright Copyright (c) 2023
9 : *
10 : */
11 : #include "cppdescent/cppdescent.hpp"
12 : #include <gsl/gsl_blas.h>
13 : #include <gsl/gsl_vector.h>
14 : #include <omp.h>
15 : #include <cmath>
16 : #include <cstdint>
17 : #include <iostream>
18 :
19 : bool verbose = true;
20 : int dimensions = 100;
21 :
22 : //===================================
23 : // Helper functions.
24 : //===================================
25 :
26 18 : int cppdescent::compareGraphVertices(Pointer vertex1, Pointer vertex2) {
27 18 : GraphVertex* gvertex1 = (GraphVertex*)vertex1;
28 18 : GraphVertex* gvertex2 = (GraphVertex*)vertex2;
29 :
30 18 : if (gvertex1->getPos() == gvertex2->getPos())
31 12 : return 0;
32 :
33 6 : return 1;
34 : }
35 :
36 0 : void destroyEdges(GraphVertexPair* pair) {
37 0 : delete pair;
38 0 : }
39 :
40 0 : int cppdescent::compareGraphVertexPairs(Pointer p1, Pointer p2) {
41 0 : GraphVertexPair* pair1 = (GraphVertexPair*)p1;
42 0 : GraphVertexPair* pair2 = (GraphVertexPair*)p2;
43 :
44 0 : GraphVertex* n11 = (GraphVertex*)pair1->getVertex1();
45 0 : GraphVertex* n12 = (GraphVertex*)pair1->getVertex2();
46 0 : GraphVertex* n21 = (GraphVertex*)pair2->getVertex1();
47 0 : GraphVertex* n22 = (GraphVertex*)pair2->getVertex2();
48 :
49 0 : if (n11->getPos() == n21->getPos() && n12->getPos() == n22->getPos())
50 0 : return 0;
51 :
52 0 : return 1;
53 : }
54 :
55 : //===================================
56 : // CPPDescent functions.
57 : //===================================
58 :
59 10 : Vector* cppdescent::readBinData(const char* fp, int dimensions) {
60 10 : FILE* data = fopen(fp, "rb");
61 :
62 : uint32_t N;
63 :
64 10 : fread(&N, sizeof(uint32_t), 1, data);
65 :
66 10 : Vector* elements = new Vector(N, nullptr);
67 :
68 : float value;
69 :
70 210 : for (int i = 0; i < (int)N; i++) {
71 200 : gsl_vector* datapoints = gsl_vector_alloc(dimensions);
72 :
73 20200 : for (int j = 0; j < dimensions; j++) {
74 20000 : fread(&value, sizeof(float), 1, data);
75 20000 : gsl_vector_set(datapoints, j, value);
76 : }
77 :
78 200 : elements->setAt(i, datapoints);
79 : }
80 :
81 10 : fclose(data);
82 :
83 10 : return elements;
84 : }
85 :
86 0 : void cppdescent::writeBinGraph(const char* fp, Graph* graph, int K) {
87 0 : FILE* file = fopen(fp, "w+");
88 :
89 0 : Vector* vec = graph->getVec();
90 0 : uint32_t N = vec->getSize();
91 :
92 : // number of vertices
93 0 : fwrite(&N, sizeof(N), 1, file);
94 :
95 : // number of neighbors
96 0 : fwrite(&K, sizeof(K), 1, file);
97 :
98 : // vertices
99 0 : for (int i = 0; i < (int)N; i++) {
100 0 : GraphVertex* gvertex = (GraphVertex*)vec->getAt(i);
101 0 : gsl_vector* vertex = (gsl_vector*)gvertex->getData();
102 0 : int dimensions = vertex->size;
103 :
104 0 : for (int j = 0; j < dimensions; j++) {
105 0 : float datapoint = gsl_vector_get(vertex, j);
106 0 : fwrite(&datapoint, sizeof(float), 1, file);
107 : }
108 : }
109 :
110 : // edges
111 0 : for (int i = 0; i < (int)N; i++) {
112 0 : GraphVertex* gvertex = (GraphVertex*)vec->getAt(i);
113 : // neighbors
114 0 : Vector* neighbors = gvertex->getNeighbors()->toVector();
115 0 : for (int k = 0; k < K; k++) {
116 : // get the kth neigbor from the vector
117 : Pointer neighbor =
118 0 : (((GraphVertexPair*)neighbors->getAt(k)))->getVertex2();
119 : // find its position in the graph's vector
120 0 : int pos = vec->findPos(neighbor, compareGraphVertices);
121 0 : fwrite(&pos, sizeof(int), 1, file);
122 : }
123 : }
124 :
125 0 : fclose(file);
126 0 : }
127 :
128 0 : Graph* cppdescent::readBinGraph(const char* fp, int dimensions) {
129 0 : FILE* file = fopen(fp, "r");
130 0 : if (file == nullptr)
131 : return nullptr; // LCOV_EXCL_LINE
132 :
133 0 : Graph* graph = new Graph(nullptr, nullptr);
134 :
135 : uint32_t N;
136 : int K;
137 :
138 0 : fread(&N, sizeof(N), 1, file);
139 0 : fread(&K, sizeof(K), 1, file);
140 :
141 : float datapoint;
142 : // read the vertices
143 0 : for (int i = 0; i < (int)N; i++) {
144 0 : gsl_vector* vertex = gsl_vector_alloc(dimensions);
145 :
146 0 : for (int j = 0; j < dimensions; j++) {
147 0 : fread(&datapoint, sizeof(float), 1, file);
148 0 : gsl_vector_set(vertex, j, datapoint);
149 : }
150 :
151 0 : graph->insertVertex(vertex);
152 : }
153 :
154 0 : for (int i = 0; i < (int)N; i++) {
155 0 : GraphVertex* v1 = (GraphVertex*)graph->getVec()->getAt(i);
156 : int pos;
157 0 : for (int k = 0; k < K; k++) {
158 0 : fread(&pos, sizeof(int), 1, file);
159 0 : GraphVertex* v2 = (GraphVertex*)graph->getVec()->getAt(pos);
160 0 : graph->insertEdge(v1, v2);
161 : }
162 : }
163 :
164 0 : fclose(file);
165 0 : return graph;
166 : }
167 :
168 0 : float cppdescent::recall(Graph* bfGraph, Graph* nnGraph, int N, int K) {
169 0 : Vector* bfVertices = bfGraph->getVerticesV();
170 0 : Vector* nnVertices = nnGraph->getVerticesV();
171 :
172 0 : float recall = 0;
173 :
174 0 : for (int node = 0; node < bfVertices->getSize(); node++) {
175 0 : int trueNeighbors = 0;
176 0 : Vector* bfNodeAdjacent = bfGraph->getAdjacentV(bfVertices->getAt(node));
177 :
178 0 : Vector* nnAdjacent = nnGraph->getAdjacentV(nnVertices->getAt(node));
179 :
180 0 : for (int adjacent = 0; adjacent < nnAdjacent->getSize(); adjacent++)
181 0 : if (bfNodeAdjacent->find(nnAdjacent->getAt(adjacent),
182 0 : cppdescent::compareGraphVertexPairs) != nullptr)
183 0 : trueNeighbors++;
184 :
185 0 : recall += (float)trueNeighbors / (float)K;
186 : }
187 :
188 0 : recall = recall / (float)N;
189 0 : recall *= 100;
190 :
191 0 : return recall;
192 : }
193 :
194 0 : Graph* cppdescent::KNNBruteForceGraph(Vector* data,
195 : int K,
196 : CompareFunc compare) {
197 0 : Graph* graph = new Graph(nullptr, nullptr);
198 :
199 : // Insert all points as vertices.
200 0 : int N = data->getSize();
201 0 : for (int i = 0; i < N; i++)
202 0 : graph->insertVertex(data->getAt(i));
203 :
204 0 : for (int i = 0; i < N; i++) {
205 0 : Pointer a = graph->getVec()->getAt(i);
206 0 : PQueue* neighbors = new PQueue(compare, nullptr, nullptr);
207 :
208 0 : for (int j = 0; j < N; j++) {
209 0 : if (i == j)
210 0 : continue;
211 :
212 0 : Pointer b = graph->getVec()->getAt(j);
213 0 : GraphVertexPair* pair = new GraphVertexPair(graph, a, b);
214 :
215 0 : if (neighbors->getSize() < K) {
216 0 : neighbors->insert(pair);
217 0 : } else if (compare(pair, neighbors->getMax()) < 0) {
218 0 : GraphVertexPair* max = (GraphVertexPair*)neighbors->getMax();
219 0 : neighbors->removeMax();
220 0 : delete max;
221 0 : neighbors->insert(pair);
222 : } else {
223 0 : destroyEdges(pair);
224 : }
225 : }
226 :
227 0 : for (int k = 0; k < K; k++) {
228 0 : GraphVertexPair* neighbor = (GraphVertexPair*)neighbors->getMax();
229 0 : neighbors->removeMax();
230 0 : Pointer vec = ((GraphVertexPair*)neighbor)->getVertex2();
231 0 : delete neighbor;
232 0 : graph->insertEdge(a, vec);
233 : }
234 :
235 0 : delete neighbors;
236 : }
237 :
238 0 : return graph;
239 : }
240 :
241 : /**
242 : * @brief Creates a random graph, where each vertex has K random neighbors.
243 : *
244 : * The user is responsible for deallocating the graph.
245 : *
246 : * @param data The Vector with all the points.
247 : * @param K
248 : * @param compare
249 : * @param distance
250 : * @return Graph* The created graph.
251 : */
252 0 : Graph* sampleGraph(Vector* data, int K) {
253 0 : Graph* graph = new Graph(nullptr, nullptr);
254 :
255 0 : int N = data->getSize();
256 :
257 0 : srand(time(0));
258 :
259 : // Insert all points as vertices.
260 0 : for (int i = 0; i < N; i++)
261 0 : graph->insertVertex(data->getAt(i));
262 :
263 : // Iterate all of the vertices.
264 0 : for (int i = 0; i < N; i++) {
265 : // For each vertex, add K random neighbors.
266 0 : Pointer v1 = graph->getVec()->getAt(i);
267 0 : for (int j = 0; j < K; j++) {
268 0 : int randPos = rand() % N; // The position of the neighbor.
269 :
270 : // Avoid adding itself as a neighbor.
271 0 : while (randPos == i)
272 0 : randPos = rand() % N;
273 :
274 : // Get the two vertices and create an edge between them.
275 0 : Pointer v2 = graph->getVec()->getAt(randPos);
276 0 : while (graph->isNeighborVertex(v1, v2) == true) {
277 0 : randPos = rand() % N;
278 0 : v2 = (Pointer)graph->getVec()->getAt(randPos);
279 : }
280 0 : graph->insertEdge(v1, v2);
281 : }
282 : }
283 :
284 0 : return graph;
285 : }
286 :
287 0 : int updateNN(Graph* graph,
288 : int K,
289 : Pointer u1,
290 : Pointer u2,
291 : float dist,
292 : DistanceFunc distance) {
293 0 : PQueue* direct = ((GraphVertex*)u1)->getNeighbors();
294 0 : int size = direct->getSize();
295 :
296 0 : if (size < K) {
297 0 : graph->insertEdge(u1, u2);
298 0 : return 1;
299 : }
300 :
301 0 : Pointer max = ((GraphVertexPair*)direct->getMax())->getVertex2();
302 0 : float maxDist = distance(u1, max);
303 :
304 0 : if (dist < maxDist) {
305 0 : graph->insertEdge(u1, u2);
306 0 : if (direct->getSize() == size + 1) {
307 0 : graph->removeEdge(u1, max);
308 0 : return 1;
309 : }
310 : }
311 :
312 0 : return 0;
313 : }
314 :
315 : struct sets {
316 : Vector* new_v;
317 : Vector* old_v;
318 : };
319 :
320 : /**
321 : * @brief Get the Sets object
322 : *
323 : * Returns a `sets` struct containing a vector pointer to the new[v] set and
324 : * another to the old[v] set.
325 : *
326 : * The first contains rho*K of direct neighbors with
327 : * their flag equal to true and rho*K reverse neighbors with true. In other
328 : * words, it contains 2*rho*K neighbors with flag = true.
329 : *
330 : * The second contains all of the direct neighbors with flag = false, which in
331 : * the worst case will be K, and rho*K of the reverse neighbors with flag =
332 : * false. In other words, K + rho*K neighbors with flag = false.
333 : *
334 : * @param neighbors
335 : * @param K
336 : * @param rho
337 : * @return struct sets
338 : */
339 0 : struct sets getSets(Vector* neighbors, int K, float rho) {
340 0 : int** trueMetadata = new int*[neighbors->getSize()];
341 0 : for (int i = 0; i < neighbors->getSize(); ++i)
342 0 : trueMetadata[i] = new int[3];
343 :
344 0 : int** reverseFalseMetadata = new int*[neighbors->getSize() - K];
345 0 : for (int i = 0; i < neighbors->getSize() - K; ++i)
346 0 : reverseFalseMetadata[i] = new int[2];
347 :
348 0 : for (int i = 0; i < neighbors->getSize(); i++) {
349 : // has been added?
350 0 : trueMetadata[i][1] = 0;
351 : // is it a direct neighbor?
352 0 : trueMetadata[i][2] = 0;
353 : }
354 :
355 0 : for (int i = 0; i < neighbors->getSize() - K; i++)
356 : // has been added?
357 0 : reverseFalseMetadata[i][1] = 0;
358 :
359 : struct sets sets;
360 : // rhoK of direct true and rhoK of reverse true
361 0 : sets.new_v = new Vector(0, nullptr);
362 : // K of direct false and rhoK of reverse false
363 0 : sets.old_v = new Vector(0, nullptr);
364 :
365 0 : int trues = 0;
366 0 : int falses = 0;
367 :
368 0 : for (int i = 0; i < neighbors->getSize(); i++) {
369 0 : GraphVertexPair* pair = (GraphVertexPair*)neighbors->getAt(i);
370 :
371 : // neighbor has flag = true
372 0 : if (pair->getFlag()) {
373 : // in the `trues` array add this neighbor's position
374 0 : trueMetadata[trues][0] = i;
375 0 : pair->setFalse();
376 0 : if (i < K)
377 0 : trueMetadata[trues][2] = 1;
378 0 : trues++;
379 0 : } else if (i < K) {
380 : // neighbor has flag = false and is direct
381 0 : Pointer v = pair->getVertex2();
382 0 : sets.old_v->insertLast(v);
383 : } else {
384 : // neighbor has flag = false and is reverse, so needs sampling
385 0 : reverseFalseMetadata[falses][0] = i;
386 0 : falses++;
387 : }
388 : }
389 :
390 0 : if (trues < 2 * rho * K) {
391 : // if there are less trues than 2ρK, simply put them all
392 0 : for (int i = 0; i < trues; i++) {
393 : Pointer v;
394 0 : if (trueMetadata[i][2])
395 0 : v = ((GraphVertexPair*)neighbors->getAt(i))->getVertex2();
396 : else
397 0 : v = ((GraphVertexPair*)neighbors->getAt(i))->getVertex1();
398 0 : sets.new_v->insertLast(v);
399 : }
400 : } else {
401 : // new[v] sampling
402 0 : for (int i = 0; i < 2 * rho * K; i++) {
403 0 : int selected = rand() % trues;
404 :
405 : // if it already has been selected, choose another
406 0 : while (trueMetadata[selected][1])
407 0 : selected = rand() % trues;
408 :
409 : // has now been selected, do not select again
410 0 : trueMetadata[selected][1] = 1;
411 :
412 : Pointer v;
413 0 : if (trueMetadata[selected][2])
414 0 : v = ((GraphVertexPair*)neighbors->getAt(trueMetadata[selected][0]))
415 0 : ->getVertex2();
416 : else
417 0 : v = ((GraphVertexPair*)neighbors->getAt(trueMetadata[selected][0]))
418 0 : ->getVertex1();
419 :
420 0 : sets.new_v->insertLast(v);
421 : }
422 : }
423 :
424 0 : if (falses < rho * K) {
425 0 : for (int i = 0; i < falses; i++) {
426 0 : Pointer v = ((GraphVertexPair*)neighbors->getAt(i))->getVertex1();
427 0 : sets.old_v->insertLast(v);
428 : }
429 : } else {
430 : // old[v] sampling
431 0 : for (int i = 0; i < rho * K; i++) {
432 0 : int selected = rand() % falses;
433 :
434 : // if it already has been selected, choose another
435 0 : while (reverseFalseMetadata[selected][1])
436 0 : selected = rand() % falses;
437 :
438 : // has now been selected, do not select again
439 0 : reverseFalseMetadata[selected][1] = 1;
440 :
441 0 : Pointer v = ((GraphVertexPair*)neighbors->getAt(
442 0 : reverseFalseMetadata[selected][0]))
443 0 : ->getVertex1();
444 0 : sets.old_v->insertLast(v);
445 : }
446 : }
447 :
448 0 : for (int i = 0; i < neighbors->getSize(); ++i)
449 0 : delete[] trueMetadata[i];
450 0 : delete[] trueMetadata;
451 :
452 0 : for (int i = 0; i < neighbors->getSize() - K; ++i)
453 0 : delete[] reverseFalseMetadata[i];
454 0 : delete[] reverseFalseMetadata;
455 :
456 0 : return sets;
457 : }
458 :
459 0 : Graph* cppdescent::NNDescent_KNNGraph(Vector* data,
460 : int K,
461 : int D,
462 : int Trees,
463 : float delta,
464 : float rho,
465 : DistanceFunc distance) {
466 0 : if (verbose)
467 0 : std::cout << "\tInitializing starting graph...\n";
468 :
469 0 : Graph* graph = nullptr;
470 :
471 0 : if (D != 0) {
472 0 : std::cout << "\tUsing random projection tree...\n";
473 0 : graph = new Graph(nullptr, nullptr);
474 0 : for (int i = 0; i < data->getSize(); i++)
475 0 : graph->insertVertex(data->getAt(i));
476 0 : for (int i = 0; i < Trees; i++)
477 0 : RPTree(graph, nullptr, K, D, dimensions);
478 : } else
479 0 : graph = sampleGraph(data, K);
480 :
481 0 : if (verbose)
482 0 : std::cout << "\tStarting graph has been created\n";
483 : // The vertices do not change, only the edges between them are modified. So we
484 : // only need to get them once and not in each iteration.
485 0 : Vector* vertices = graph->getVerticesV();
486 0 : int N = graph->getSize();
487 : int c;
488 0 : int iterations = 0;
489 : float dist;
490 :
491 0 : struct sets* allSets = new struct sets[N];
492 :
493 : do {
494 0 : iterations++;
495 :
496 0 : c = 0;
497 :
498 0 : #pragma omp parallel for
499 : for (int v = 0; v < N; v++) {
500 : // vAll = Bbar[v] = B[v] ⋃ R[v]
501 : Vector* vAll = graph->getGeneralNeighborsV(vertices->getAt(v));
502 :
503 : allSets[v] = getSets(vAll, K, rho);
504 :
505 : delete vAll;
506 : }
507 :
508 0 : #pragma omp parallel for ordered schedule(dynamic)
509 : for (int v = 0; v < N; v++) {
510 : struct sets sets = allSets[v];
511 :
512 : Vector* new_v = sets.new_v;
513 : Vector* old_v = sets.old_v;
514 :
515 : for (int U1 = 0; U1 < new_v->getSize(); U1++) {
516 : for (int U2 = U1 + 1; U2 < new_v->getSize(); U2++) {
517 : GraphVertex* u1 = (GraphVertex*)new_v->getAt(U1);
518 : GraphVertex* u2 = (GraphVertex*)new_v->getAt(U2);
519 :
520 : dist = distance(u1, u2);
521 :
522 : #pragma omp ordered
523 : c += updateNN(graph, K, u1, u2, dist, distance);
524 : #pragma omp ordered
525 : c += updateNN(graph, K, u2, u1, dist, distance);
526 : }
527 :
528 : for (int U2 = 0; U2 < old_v->getSize(); U2++) {
529 : GraphVertex* u1 = (GraphVertex*)new_v->getAt(U1);
530 : GraphVertex* u2 = (GraphVertex*)old_v->getAt(U2);
531 :
532 : dist = distance(u1, u2);
533 :
534 : #pragma omp ordered
535 : c += updateNN(graph, K, u1, u2, dist, distance);
536 : #pragma omp ordered
537 : c += updateNN(graph, K, u2, u1, dist, distance);
538 : }
539 : }
540 :
541 : delete new_v;
542 : delete old_v;
543 : }
544 :
545 0 : if (verbose)
546 0 : std::cout << "\tNumber of changes in the graph (c) = " << c << "\n";
547 0 : } while (c >= delta * N * K);
548 :
549 0 : delete[] allSets;
550 :
551 0 : if (verbose)
552 0 : std::cout << "\tNN-Descent iterations: " << iterations << "\n";
553 :
554 0 : return graph;
555 : }
556 :
557 : // LCOV_EXCL_START
558 : // PQueue* cppdescent::NNDescent_Query(Graph* graph,
559 : // int K,
560 : // CompareFunc compare,
561 : // Vector* query) {
562 : // List* vertices = graph->getVertices();
563 :
564 : // srand(time(0));
565 :
566 : // int dimensions =
567 : // ((Vector*)((GraphVertex*)((Vector*)graph->getVec()->first()->getValue()))
568 : // ->getData())
569 : // ->getSize();
570 : // if (query->getSize() != dimensions)
571 : // return nullptr;
572 :
573 : // // get random candidate from graph
574 : // int pos = rand() % vertices->getSize();
575 : // Pointer candidate = ((GraphVertex*)graph->getVec()->getAt(pos))->getData();
576 :
577 : // PQueue* knn = new PQueue(compare, (DestroyFunc)destroyEdges, nullptr);
578 :
579 : // List* candidates;
580 : // bool candidatesRemain = true;
581 :
582 : // GraphVertex* queryVertex = new GraphVertex(query, graph);
583 :
584 : // while (candidatesRemain) {
585 : // candidatesRemain = false;
586 :
587 : // // get candidate's neighbors
588 : // candidates = graph->getGeneralNeighborsVertices(candidate);
589 :
590 : // // add best candidate's neighbors to the queue
591 : // for (ListNode* node = candidates->getHead(); node != nullptr;
592 : // node = node->getNext())
593 : // if (!((GraphVertex*)node->getValue())->checked()) {
594 : // ((GraphVertex*)node->getValue())->check();
595 :
596 : // candidatesRemain = true;
597 :
598 : // GraphVertexPair* pair =
599 : // new GraphVertexPair(graph, queryVertex, node->getValue());
600 :
601 : // knn->insert(pair);
602 : // }
603 :
604 : // // truncuate queue to K
605 : // while (knn->getSize() > K)
606 : // knn->removeMax();
607 :
608 : // // get new best candidate
609 : // candidate =
610 : // ((GraphVertex*)((GraphVertexPair*)knn->getMin())->getVertex2())
611 : // ->getData();
612 :
613 : // delete candidates;
614 : // }
615 :
616 : // delete vertices;
617 : // delete queryVertex;
618 : // return knn;
619 : // }
620 : // LCOV_EXCL_STOP
621 :
622 : // ============================ Metric Functions =============================
623 :
624 0 : float cppdescent::euclideanDistance(Pointer a, Pointer b) {
625 0 : GraphVertex* first = (GraphVertex*)a;
626 0 : GraphVertex* second = (GraphVertex*)b;
627 0 : float result = 0;
628 :
629 0 : double x2 = first->getNorm();
630 0 : double y2 = second->getNorm();
631 :
632 : double xy;
633 0 : gsl_blas_ddot((gsl_vector*)first->getData(), (gsl_vector*)second->getData(),
634 : &xy);
635 :
636 0 : result = x2 + y2 - 2 * xy;
637 :
638 0 : return result;
639 : }
640 :
641 0 : int cppdescent::compareEdgesEuclidean(Pointer first, Pointer second) {
642 0 : GraphVertexPair* pair1 = (GraphVertexPair*)first;
643 0 : GraphVertexPair* pair2 = (GraphVertexPair*)second;
644 :
645 0 : float a = euclideanDistance(pair1->getVertex1(), pair1->getVertex2());
646 0 : float b = euclideanDistance(pair2->getVertex1(), pair2->getVertex2());
647 :
648 0 : int value = 0;
649 0 : if (b > a) {
650 0 : value = -1;
651 0 : } else if (a > b) {
652 0 : value = 1;
653 : }
654 0 : return value;
655 : }
656 :
657 : //===================================
658 : // RPTrees functions.
659 : //===================================
660 :
661 0 : void cppdescent::RPTree(Graph* graph,
662 : Vector* vec,
663 : int K,
664 : int D,
665 : int dimensions) {
666 0 : const float epsilon = 1e-8;
667 :
668 : // first call of function, whole graph is passed
669 0 : if (vec == nullptr)
670 0 : vec = graph->getVerticesV();
671 :
672 0 : int size = vec->getSize();
673 :
674 : // We are at a leaf
675 0 : if (size <= D) {
676 : // connect all vertices in the leaf with eachother
677 0 : #pragma omp critical
678 : {
679 0 : for (int i = 0; i < size; i++)
680 0 : for (int j = 0; j < size; j++)
681 0 : if (i != j) {
682 0 : float dist = euclideanDistance(vec->getAt(i), vec->getAt(j));
683 0 : updateNN(graph, K, vec->getAt(i), vec->getAt(j), dist,
684 : euclideanDistance);
685 : }
686 :
687 0 : for (int d = size; d <= K; d++) {
688 0 : int randPos = rand() % graph->getSize();
689 0 : while (vec->find(graph->getVerticesV()->getAt(randPos),
690 0 : compareGraphVertices) != nullptr)
691 0 : randPos = rand() % graph->getSize();
692 :
693 0 : for (int i = 0; i < size; i++) {
694 : float dist =
695 0 : euclideanDistance(vec->getAt(i), graph->getVec()->getAt(randPos));
696 0 : updateNN(graph, K, vec->getAt(i),
697 : graph->getVerticesV()->getAt(randPos), dist,
698 : euclideanDistance);
699 : }
700 : }
701 : }
702 0 : return;
703 : }
704 :
705 0 : int pos0 = rand() % size;
706 0 : int pos1 = rand() % size;
707 :
708 0 : if (pos0 == pos1)
709 0 : pos0 = (pos1 + 1) % size;
710 :
711 0 : GraphVertex* g0 = (GraphVertex*)vec->getAt(pos0);
712 0 : GraphVertex* g1 = (GraphVertex*)vec->getAt(pos1);
713 :
714 0 : gsl_vector* midpoint = gsl_vector_alloc(dimensions);
715 0 : gsl_vector* hyperplane = gsl_vector_alloc(dimensions);
716 :
717 0 : #pragma omp parallel for
718 : for (int i = 0; i < dimensions; i++) {
719 : float mid_value = (gsl_vector_get((gsl_vector*)g0->getData(), i) +
720 : gsl_vector_get((gsl_vector*)g1->getData(), i)) /
721 : 2.0;
722 : gsl_vector_set(midpoint, i, mid_value);
723 : float hyper_value = (gsl_vector_get((gsl_vector*)g0->getData(), i) -
724 : gsl_vector_get((gsl_vector*)g1->getData(), i));
725 : gsl_vector_set(hyperplane, i, hyper_value);
726 : }
727 :
728 : double offset;
729 0 : gsl_blas_ddot(midpoint, hyperplane, &offset);
730 :
731 0 : int cnt0 = 0;
732 0 : int cnt1 = 0;
733 :
734 0 : int* side = new int[size];
735 :
736 : // split the vertices
737 0 : for (int i = 0; i < size; i++) {
738 0 : GraphVertex* gi = (GraphVertex*)vec->getAt(i);
739 0 : gsl_vector* veci = (gsl_vector*)gi->getData();
740 :
741 : double margin;
742 0 : gsl_blas_ddot(hyperplane, veci, &margin);
743 :
744 0 : if (margin > offset + epsilon) {
745 0 : cnt1++;
746 0 : side[i] = 1;
747 0 : } else if (margin < offset - epsilon) {
748 0 : cnt0++;
749 0 : side[i] = 0;
750 0 : } else if (rand() % 2 == 0) {
751 0 : cnt1++;
752 0 : side[i] = 1;
753 : } else {
754 0 : cnt0++;
755 0 : side[i] = 0;
756 : }
757 : }
758 :
759 : // if all vertices are on one side
760 0 : if (cnt1 == 0 || cnt0 == 0) {
761 0 : cnt0 = 0;
762 0 : cnt1 = 0;
763 0 : for (int i = 0; i < size; ++i) {
764 0 : side[i] = rand() % 2;
765 0 : if (side[i] == 0) {
766 0 : ++cnt0;
767 : } else {
768 0 : ++cnt1;
769 : }
770 : }
771 : }
772 :
773 0 : Vector* side0 = new Vector(cnt0, nullptr);
774 0 : Vector* side1 = new Vector(cnt1, nullptr);
775 0 : cnt0 = 0;
776 0 : cnt1 = 0;
777 :
778 0 : for (int i = 0; i < size; i++) {
779 0 : if (side[i] == 0) {
780 0 : side0->setAt(cnt0, vec->getAt(i));
781 0 : cnt0++;
782 : } else {
783 0 : side1->setAt(cnt1, vec->getAt(i));
784 0 : cnt1++;
785 : }
786 : }
787 0 : #pragma omp parallel sections
788 : {
789 : { RPTree(graph, side0, K, D, dimensions); }
790 : #pragma omp section
791 : { RPTree(graph, side1, K, D, dimensions); }
792 : }
793 :
794 0 : gsl_vector_free(midpoint);
795 0 : gsl_vector_free(hyperplane);
796 0 : delete[] side;
797 0 : delete side0;
798 0 : delete side1;
799 : }
|