Skip to content

Commit abc07b8

Browse files
authored
Merge pull request #331 from flucoma/feature/nnsvd-random-seeding
NNDSVD: Add seedable randomness and test
2 parents 212f74d + 0e1b34b commit abc07b8

File tree

4 files changed

+45
-8
lines changed

4 files changed

+45
-8
lines changed

include/flucoma/algorithms/public/NNDSVD.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ under the European Union’s Horizon 2020 research and innovation programme
1111
#pragma once
1212

1313
#include "../util/AlgorithmUtils.hpp"
14+
#include "../util/EigenRandom.hpp"
1415
#include "../util/FluidEigenMappings.hpp"
1516
#include "../../data/FluidIndex.hpp"
1617
#include "../../data/TensorTypes.hpp"
@@ -28,7 +29,7 @@ class NNDSVD
2829

2930
index process(RealMatrixView X, RealMatrixView W, RealMatrixView H,
3031
index minRank = 0, index maxRank = 200, double amount = 0.8,
31-
index method = 0) // 0 - NMF-SVD, 1 NNDSVDar, 2 NNDSVDa 3 NNDSVD
32+
index method = 0, index seed = -1) // 0 - NMF-SVD, 1 NNDSVDar, 2 NNDSVDa 3 NNDSVD
3233
{
3334
using namespace _impl;
3435
using namespace Eigen;
@@ -101,14 +102,16 @@ class NNDSVD
101102
WT.col(j) = u; // avoid scaling for NMF with normalized W
102103
HT.row(j) = lbd * v;
103104
}
104-
WT = WT.array().max(epsilon);
105-
HT = HT.array().max(epsilon);
105+
106+
double mean = XT.mean();
106107
if (method == 1)
107108
{
108109
auto Wrand =
109-
MatrixXd::Random(WT.rows(), WT.cols()).array().abs() / 100.0;
110+
EigenRandom<MatrixXd>(WT.rows(), WT.cols(), RandomSeed{seed},
111+
Range{epsilon, mean * 0.001});
110112
auto Hrand =
111-
MatrixXd::Random(HT.rows(), HT.cols()).array().abs() / 100.0;
113+
EigenRandom<MatrixXd>(HT.rows(), HT.cols(), RandomSeed{seed},
114+
Range{epsilon, mean * 0.001});
112115
WT = (WT.array() < epsilon).select(Wrand, WT);
113116
HT = (HT.array() < epsilon).select(Hrand, HT);
114117
}

include/flucoma/clients/nrt/NMFSeedClient.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum NMFSeedParamIndex {
3232
kMaxRank,
3333
kCoverage,
3434
kMethod,
35+
kRandomSeed,
3536
kFFT
3637
};
3738

@@ -46,6 +47,7 @@ constexpr auto NMFSeedParams =
4647
FloatParam("coverage", "Coverage", 0.5, Min(0), Max(1)),
4748
EnumParam("method", "Initialization Method", 0, "NMF-SVD",
4849
"NNDSVDar", "NNDSVDa", "NNDSVD"),
50+
LongParam("seed", "Random Seed", -1),
4951
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));
5052

5153
class NMFSeedClient : public FluidBaseClient, public OfflineIn, public OfflineOut
@@ -100,9 +102,9 @@ class NMFSeedClient : public FluidBaseClient, public OfflineIn, public OfflineOu
100102

101103
auto nndsvd = algorithm::NNDSVD();
102104

103-
index rank = nndsvd.process(magnitude, outputFilters, outputEnvelopes,
104-
get<kMinRank>(), get<kMaxRank>(),
105-
get<kCoverage>(), get<kMethod>());
105+
index rank = nndsvd.process(
106+
magnitude, outputFilters, outputEnvelopes, get<kMinRank>(),
107+
get<kMaxRank>(), get<kCoverage>(), get<kMethod>(), get<kRandomSeed>());
106108

107109
auto filters = BufferAdaptor::Access{get<kFilters>().get()};
108110
Result resizeResult =

tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ target_link_libraries(TestEnvelopeGate PRIVATE TestSignals)
129129
target_link_libraries(TestTransientSlice PRIVATE TestSignals)
130130

131131
add_test_executable(TestEigenRandom algorithms/util/TestEigenRandom.cpp)
132+
add_test_executable(TestNNDSVD algorithms/public/TestNNDSVD.cpp)
132133
add_test_executable(TestRTPGHI algorithms/util/TestRTPGHI.cpp)
133134

134135
include(CTest)
@@ -158,6 +159,7 @@ catch_discover_tests(TestBufferedProcess WORKING_DIRECTORY "${CMAKE_BINARY_DIR}"
158159
catch_discover_tests(TestMLP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
159160
catch_discover_tests(TestKMeans WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
160161
catch_discover_tests(TestEigenRandom WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
162+
catch_discover_tests(TestNNDSVD WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
161163
catch_discover_tests(TestNMF WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
162164
catch_discover_tests(TestRTPGHI WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
163165
catch_discover_tests(TestUMAP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#define CATCH_CONFIG_MAIN
2+
#include <flucoma/algorithms/public/NNDSVD.hpp>
3+
#include <flucoma/data/FluidTensor.hpp>
4+
#include <catch2/catch_all.hpp>
5+
#include <vector>
6+
7+
TEST_CASE("NNDSVD Mode 1 is repeatable with manually set random seed"){
8+
9+
using Tensor = fluid::FluidTensor<double,2>;
10+
using fluid::algorithm::NNDSVD;
11+
12+
// To test the effect of randomness in NNDSVD mode 1, there must be 0s in the input
13+
Tensor input = {{0,0,0},{0,0,0},{0,0,0}};
14+
15+
std::vector Ws(3, Tensor(3,3));
16+
std::vector Hs(3, Tensor(3,3));
17+
18+
NNDSVD algo;
19+
20+
algo.process(input,Ws[0],Hs[0],2,2,0.8,1, 42);
21+
algo.process(input,Ws[1],Hs[1],2,2,0.8,1, 42);
22+
algo.process(input,Ws[2],Hs[2],2,2,0.8,1, 4672);
23+
24+
using Catch::Matchers::RangeEquals;
25+
26+
REQUIRE_THAT(Ws[1],RangeEquals(Ws[0]));
27+
REQUIRE_THAT(Ws[1],!RangeEquals(Ws[2]));
28+
REQUIRE_THAT(Hs[1],RangeEquals(Hs[0]));
29+
REQUIRE_THAT(Hs[1],!RangeEquals(Hs[2]));
30+
}

0 commit comments

Comments
 (0)