diff --git a/include/flucoma/algorithms/public/KMeans.hpp b/include/flucoma/algorithms/public/KMeans.hpp index 3302e888..ed58a93f 100644 --- a/include/flucoma/algorithms/public/KMeans.hpp +++ b/include/flucoma/algorithms/public/KMeans.hpp @@ -232,6 +232,7 @@ class KMeans index size() const { return mMeans.rows(); } index getK() const { return mMeans.rows(); } index nAssigned() const { return mAssignments.size(); } + index nEmpty() const { return std::count(mEmpty.begin(), mEmpty.end(), true); } void getAssignments(FluidTensorView out) const { @@ -295,7 +296,6 @@ class KMeans } if (kAssignment.size() == 0) { - std::cout << "Warning: empty cluster" << std::endl; mEmpty[asUnsigned(k)] = true; return; } diff --git a/include/flucoma/clients/nrt/KMeansClient.hpp b/include/flucoma/clients/nrt/KMeansClient.hpp index 45a2906a..1ffab321 100644 --- a/include/flucoma/clients/nrt/KMeansClient.hpp +++ b/include/flucoma/clients/nrt/KMeansClient.hpp @@ -85,11 +85,9 @@ class KMeansClient : public FluidBaseClient, if (dataSet.size() == 0) return Error(EmptyDataSet); if (k <= 1) return Error(SmallK); if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter, static_cast(get()), - get()); - IndexVector assignments(dataSet.size()); - mAlgorithm.getAssignments(assignments); - return getCounts(assignments, k); + auto [result, _] = + train(dataSet, k, maxIter, get(), get()); + return result; } MessageResult fitPredict(InputDataSetClientRef datasetClient, @@ -106,13 +104,11 @@ class KMeansClient : public FluidBaseClient, if (k <= 1) return Error(SmallK); if (maxIter <= 0) maxIter = 100; if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter, static_cast(get()), - get()); - IndexVector assignments(dataSet.size()); - mAlgorithm.getAssignments(assignments); + auto [result, assignments] = + train(dataSet, k, maxIter, get(), get()); StringVectorView ids = dataSet.getIds(); labelsetClientPtr->setLabelSet(getLabels(ids, assignments)); - return getCounts(assignments, k); + return result; } MessageResult predict(InputDataSetClientRef datasetClient, @@ -175,12 +171,10 @@ class KMeansClient : public FluidBaseClient, if (dataSet.size() == 0) return Error(EmptyDataSet); if (k <= 1) return Error(SmallK); if (maxIter <= 0) maxIter = 100; - mAlgorithm.train(dataSet, k, maxIter, static_cast(get()), - get()); - IndexVector assignments(dataSet.size()); - mAlgorithm.getAssignments(assignments); + auto [result, _] = + train(dataSet, k, maxIter, get(), get()); transform(srcClient, dstClient); - return getCounts(assignments, k); + return result; } MessageResult predictPoint(InputBufferPtr data) const @@ -263,6 +257,25 @@ class KMeansClient : public FluidBaseClient, private: + using DataSet = FluidDataSet; + + std::pair, IndexVector> + train(DataSet const& dataSet, index k, index maxIter , index initMethod, index randomSeed) + { + mAlgorithm.train( + dataSet, k, maxIter, + static_cast(initMethod), randomSeed); + IndexVector assignments(dataSet.size()); + mAlgorithm.getAssignments(assignments); + auto training_result = MessageResult(getCounts(assignments,k)); + if(mAlgorithm.nEmpty() > 0) + { + training_result.set(Result::Status::kWarning); + training_result.addMessage("There were empty clusters; perhaps numClusters is too high."); + } + return {training_result, assignments}; + } + IndexVector getCounts(IndexVector assignments, index k) const { IndexVector counts(k); diff --git a/include/flucoma/clients/nrt/SKMeansClient.hpp b/include/flucoma/clients/nrt/SKMeansClient.hpp index f6b2f41c..afd776d2 100644 --- a/include/flucoma/clients/nrt/SKMeansClient.hpp +++ b/include/flucoma/clients/nrt/SKMeansClient.hpp @@ -84,11 +84,8 @@ class SKMeansClient : public FluidBaseClient, if (dataSet.size() == 0) return Error(EmptyDataSet); if (k <= 1) return Error(SmallK); if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter, static_cast(get()), - get()); - IndexVector assignments(dataSet.size()); - mAlgorithm.getAssignments(assignments); - return getCounts(assignments, k); + auto [result, _] = train(dataSet, k, maxIter, get(), get()); + return result; } @@ -106,13 +103,10 @@ class SKMeansClient : public FluidBaseClient, if (k <= 1) return Error(SmallK); if (maxIter <= 0) maxIter = 100; if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter, static_cast(get()), - get()); - IndexVector assignments(dataSet.size()); - mAlgorithm.getAssignments(assignments); +auto [result, assignments] = train(dataSet, k, maxIter, get(), get()); StringVectorView ids = dataSet.getIds(); labelsetClientPtr->setLabelSet(getLabels(ids, assignments)); - return getCounts(assignments, k); + return result; } @@ -178,12 +172,9 @@ class SKMeansClient : public FluidBaseClient, if (k <= 1) return Error(SmallK); if (maxIter <= 0) maxIter = 100; if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter, static_cast(get()), - get()); - IndexVector assignments(dataSet.size()); - mAlgorithm.getAssignments(assignments); +auto [result, _] = train(dataSet, k, maxIter, get(), get()); encode(srcClient, dstClient); - return getCounts(assignments, k); + return result; } MessageResult predictPoint(BufferPtr data) const @@ -266,6 +257,25 @@ class SKMeansClient : public FluidBaseClient, private: + using DataSet = FluidDataSet; + + std::pair, IndexVector> + train(DataSet const& dataSet, index k, index maxIter , index initMethod, index randomSeed) + { + mAlgorithm.train( + dataSet, k, maxIter, + static_cast(initMethod), randomSeed); + IndexVector assignments(dataSet.size()); + mAlgorithm.getAssignments(assignments); + auto training_result = MessageResult(getCounts(assignments,k)); + if(mAlgorithm.nEmpty() > 0) + { + training_result.set(Result::Status::kWarning); + training_result.addMessage("There were empty clusters; perhaps numClusters is too high."); + } + return {training_result, assignments}; + } + IndexVector getCounts(IndexVector assignments, index k) const { IndexVector counts(k);