diff --git a/include/flucoma/algorithms/public/UMAP.hpp b/include/flucoma/algorithms/public/UMAP.hpp
index 3d262a4f..1d6b7b98 100644
--- a/include/flucoma/algorithms/public/UMAP.hpp
+++ b/include/flucoma/algorithms/public/UMAP.hpp
@@ -172,7 +172,12 @@ class UMAP
   bool initialized() const { return mInitialized; }
 
   DataSet train(DataSet& in, index k = 15, index dims = 2, double minDist = 0.1,
-                index maxIter = 200, double learningRate = 1.0)
+                index maxIter = 200, double learningRate = 1.0,
+                std::optional> labels = std::nullopt,
+                double sameLabelBoost = 1.5,
+                double diffLabelPenalty = 0.1,
+                double unlabeledPenalty = 1.0
+                )
   {
     using namespace Eigen;
     using namespace _impl;
@@ -191,6 +196,12 @@ class UMAP
     computeHighDimProb(dists, sigma, knnGraph);
     SparseMatrixXd knnGraphT = knnGraph.transpose();
     knnGraph = (knnGraph + knnGraphT) - knnGraph.cwiseProduct(knnGraphT);
+
+    if (labels.has_value())
+    {
+      applySemiSupervisedWeights(knnGraph, labels.value(), sameLabelBoost,
+                                 diffLabelPenalty, unlabeledPenalty);
+    }
     mAB = findAB(minDist);
     mEmbedding = spectralEmbedding.train(knnGraph, dims);
     mEmbedding = normalizeEmbedding(mEmbedding);
@@ -270,6 +281,46 @@ class UMAP
 
 
 private:
+  // Rescales every stored edge weight of `graph` in place by exactly one
+  // factor, chosen from the labels of the edge's two end points:
+  //   - either end point carries the unlabeled marker (-1): unlabeledPenalty
+  //   - both labeled, same class:                           sameLabelBoost
+  //   - both labeled, different classes:                    diffLabelPenalty
+  // `labels` is indexed with the graph's row and column indices.
+  // The unlabeled case is tested first and excludes the other two: two -1
+  // markers compare equal, so without that an unlabeled pair would also be
+  // boosted as "same class" and a half-labeled edge would also be penalised
+  // as "different class".
+  // NOTE(review): the template arguments of FluidTensorView are missing from
+  // this patch text as received - confirm against the real header.
+  void applySemiSupervisedWeights(SparseMatrixXd& graph,
+                                  FluidTensorView labels,
+                                  double sameLabelBoost = 1.5,
+                                  double diffLabelPenalty = 0.1,
+                                  double unlabeledPenalty = 1.0) const
+  {
+    const int unlabeledMarker = -1;
+
+    for (index k = 0; k < graph.outerSize(); ++k)
+    {
+      for (SparseMatrixXd::InnerIterator it(graph, k); it; ++it)
+      {
+        index row = it.row();
+        index col = it.col();
+        const int labelRow = labels(row);
+        const int labelCol = labels(col);
+
+        double factor = diffLabelPenalty;
+        if (labelRow == unlabeledMarker || labelCol == unlabeledMarker)
+          factor = unlabeledPenalty;
+        else if (labelRow == labelCol)
+          factor = sameLabelBoost;
+
+        it.valueRef() *= factor;
+      }
+    }
+  }
+
   template 
   void traverseGraph(Eigen::SparseCompressedBase& graph, F func) const
   {
diff --git 
a/include/flucoma/clients/nrt/UMAPClient.hpp b/include/flucoma/clients/nrt/UMAPClient.hpp
index 593104a9..e573a290 100644
--- a/include/flucoma/clients/nrt/UMAPClient.hpp
+++ b/include/flucoma/clients/nrt/UMAPClient.hpp
@@ -11,8 +11,12 @@ under the European Union’s Horizon 2020 research and innovation programme
 #pragma once
 
 #include "DataSetClient.hpp"
+#include "LabelSetClient.hpp"
 #include "NRTClient.hpp"
+#include "../common/SharedClientUtils.hpp"
 #include "../../algorithms/public/UMAP.hpp"
+#include
+#include
 
 namespace fluid {
 namespace client {
@@ -24,7 +28,11 @@ constexpr auto UMAPParams = defineParameters(
     LongParam("numNeighbours", "Number of Nearest Neighbours", 15, Min(1)),
     FloatParam("minDist", "Minimum Distance", 0.1, Min(0)),
     LongParam("iterations", "Number of Iterations", 200, Min(1)),
-    FloatParam("learnRate", "Learning Rate", 0.1, Min(0.0), Max(1.0)));
+    FloatParam("learnRate", "Learning Rate", 0.1, Min(0.0), Max(1.0)),
+    FloatParam("sameLabelBoost", "Same Label Boost Amount", 1.0, Min(0.0)),
+    FloatParam("diffLabelPenalty", "Different Label Penalty Amount", 1.0, Min(0.0), Max(1.0)),
+    FloatParam("unlabeledPenalty", "Unlabeled Penalty Amount", 1.0, Min(0.0), Max(1.0))
+  );
 
 class UMAPClient : public FluidBaseClient,
                    OfflineIn,
@@ -38,7 +46,10 @@ class UMAPClient : public FluidBaseClient,
     kNumNeighbors,
     kMinDistance,
     kNumIter,
-    kLearningRate
+    kLearningRate,
+    kSameLabelBoost,
+    kDiffLabelPenalty,
+    kUnlabeledPenalty,
   };
 
 public:
@@ -77,21 +88,85 @@ class UMAPClient : public FluidBaseClient,
     auto destPtr = destClient.get().lock();
     if (!srcPtr || !destPtr) return Error(NoDataSet);
     auto src = srcPtr->getDataSet();
-    auto dest = destPtr->getDataSet();
     if (src.size() == 0) return Error(EmptyDataSet);
     if (get() >= src.size())
       return Error("Number of Neighbours is greater or equal to the size of the the dataset");
+
     FluidDataSet result;
     try
     {
-      result = mAlgorithm.train(src, get(), get(),
-                                get(), get(),
-                                get());
+      // Unsupervised call: std::nullopt fills the labels argument.
+      result = mAlgorithm.train(src, get(), get(), get(), get(), get(),
+                                std::nullopt, get(), get(), get());
+    } catch (const std::runtime_error& e) //spectra library will throw if eigen decomp fails
+    {
+      return {Result::Status::kError, e.what()};
     }
-    catch (const std::runtime_error& e) //spectra library will throw if eigen decomp fails
+    destPtr->setDataSet(result);
+    return OK();
+  }
+
+  // Handler for the "semisup" message: runs the same training as fitTransform
+  // but additionally hands train() one integer class id per source point.
+  // Distinct label strings get ids 0, 1, 2, ... in order of first appearance;
+  // points whose id yields no label from labelSet.get() keep -1. When the label
+  // set reference cannot be locked, std::nullopt is passed instead.
+  // NOTE(review): every <...> template argument list (get(), std::optional,
+  // std::map, FluidTensor, FluidDataSet) is missing from this patch text as
+  // received - restore them from the real header before applying.
+  MessageResult semiSupervised(InputDataSetClientRef  sourceClient,
+                               DataSetClientRef       destClient,
+                               InputLabelSetClientRef labelsClient)
+  {
+    auto srcPtr = sourceClient.get().lock();
+    auto destPtr = destClient.get().lock();
+    if (!srcPtr || !destPtr) return Error(NoDataSet);
+
+    auto src = srcPtr->getDataSet();
+    if (src.size() == 0) return Error(EmptyDataSet);
+    if (get() >= src.size())
+      return Error("Number of Neighbours is greater or equal to the size of "
+                   "the the dataset");
+
+    std::optional> labels;
+    if (auto labelsPtr = labelsClient.get().lock())
+    {
+      auto labelSet = labelsPtr->getLabelSet();
+      auto srcIds = src.getIds();
+      std::map classOfLabel;
+      int classCount = 0;
+      FluidTensor pointClasses(src.size());
+      pointClasses.fill(-1);
+      StringVector found(1); // scratch buffer, overwritten by each lookup
+
+      for (index i = 0; i < src.size(); ++i)
+      {
+        if (!labelSet.get(srcIds(i), found)) continue;
+
+        // One map lookup: inserts classCount for a new label, otherwise
+        // returns the id already stored for it.
+        const string label = found(0);
+        const auto slot = classOfLabel.try_emplace(label, classCount);
+        if (slot.second) ++classCount;
+        pointClasses(i) = slot.first->second;
+      }
+      // NOTE(review): if `labels` holds a non-owning view, it would outlive
+      // pointClasses here - confirm the stored type owns its data.
+      labels = pointClasses;
+    }
+
+    FluidDataSet result;
+    try
+    {
+      result = mAlgorithm.train(
+          src, get(), get(), get(), get(), get(),
+          labels ? std::optional>(*labels) : std::nullopt,
+          get(), get(), get());
+    } catch (const std::runtime_error& e)
     {
       return {Result::Status::kError, e.what()};
     }
+
     destPtr->setDataSet(result);
     return OK();
   }
@@ -104,11 +179,12 @@ class UMAPClient : public FluidBaseClient,
     if (src.size() == 0) return Error(EmptyDataSet);
     if (get() > src.size())
       return Error("Number of Neighbours is larger than dataset");
-    StringVector ids{src.getIds()};
-    FluidDataSet result;
-    result = mAlgorithm.train(src, get(), get(),
-                              get(), get(),
-                              get());
+
+    // fit only trains; the DataSet returned by train() is not used here.
+    mAlgorithm.train(src, get(), get(),
+                     get(), get(),
+                     get());
+
     return OK();
   }
 
@@ -157,6 +233,7 @@ class UMAPClient : public FluidBaseClient,
   {
     return defineMessages(
         makeMessage("fitTransform", &UMAPClient::fitTransform),
+        makeMessage("semisup", &UMAPClient::semiSupervised),
         makeMessage("fit", &UMAPClient::fit),
         makeMessage("transform", &UMAPClient::transform),
         makeMessage("transformPoint", &UMAPClient::transformPoint),