101 lines
3.1 KiB
C++
101 lines
3.1 KiB
C++
#include <gtest/gtest.h>

#include <cstddef>
#include <stdexcept>
#include <vector>

#include "gpr_proc/GprProcessingPipeline.hpp"
#include "model/gpr_proc/GprAlgoParams.hpp"
|
||
|
||
using namespace geopro::core;
|
||
|
||
// 桩算法:加法器(每元素 + offset)
|
||
static void stubAdd(const float* in, float* out, const GprTraceMeta& meta, const void* params) {
|
||
const int n = meta.ntraces * meta.samples;
|
||
const float offset = *static_cast<const float*>(params);
|
||
for (int i = 0; i < n; ++i) out[i] = in[i] + offset;
|
||
}
|
||
|
||
// 桩算法:乘法器(每元素 * scale)
|
||
static void stubMul(const float* in, float* out, const GprTraceMeta& meta, const void* params) {
|
||
const int n = meta.ntraces * meta.samples;
|
||
const float scale = *static_cast<const float*>(params);
|
||
for (int i = 0; i < n; ++i) out[i] = in[i] * scale;
|
||
}
|
||
|
||
// Shared fixture: a 12-element input buffer plus the scalar parameters that
// the stub nodes read through `const void*`. GoogleTest constructs a fresh
// fixture for every TEST_F, so individual tests may mutate addVal / mulVal.
class GprPipelineTest : public ::testing::Test {
protected:
    // NOTE(review): positional aggregate init; presumably ntraces = 4 and
    // samples = 3 (4 * 3 == raw.size()) — confirm against GprTraceMeta's
    // field order.
    GprTraceMeta meta{4, 3, 1.0, 1.0, 0.0, 0.0, true, 0.12};

    // Input samples 1..12, written as 4 groups of 3.
    std::vector<float> raw{1.0f,  2.0f,  3.0f,
                           4.0f,  5.0f,  6.0f,
                           7.0f,  8.0f,  9.0f,
                           10.0f, 11.0f, 12.0f};

    // Offset consumed by stubAdd.
    float addVal{10.0f};
    // Scale consumed by stubMul.
    float mulVal{2.0f};
};
|
||
|
||
// Runs a two-node chain (add, then mul) and checks that every output element
// equals (raw[i] + addVal) * mulVal. This pins both the per-node math and the
// order in which the registered nodes are applied.
TEST_F(GprPipelineTest, basicRun) {
    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});
    pipe.registerNode({"mul", stubMul, &mulVal});

    const auto out = pipe.run(raw, meta);
    ASSERT_EQ(out.size(), raw.size());

    // Hand-computed spot checks, independent of the fixture arithmetic below.
    EXPECT_FLOAT_EQ(out[0], 22.0f);   // (1 + 10) * 2
    EXPECT_FLOAT_EQ(out[1], 24.0f);   // (2 + 10) * 2
    EXPECT_FLOAT_EQ(out[11], 44.0f);  // (12 + 10) * 2

    // Verify every element, not only the three spot checks above.
    for (std::size_t i = 0; i < out.size(); ++i) {
        EXPECT_FLOAT_EQ(out[i], (raw[i] + addVal) * mulVal) << "at index " << i;
    }
}
|
||
|
||
// Running the pipeline twice with no markDirty() in between must yield the
// same, correct output.
// NOTE(review): this checks that the results match, not that the cached
// result was actually reused; a pipeline that always recomputes also passes.
TEST_F(GprPipelineTest, cachingNoDirty) {
    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});

    const auto out1 = pipe.run(raw, meta);
    // Guard against a trivially passing comparison (e.g. two empty results).
    ASSERT_EQ(out1.size(), raw.size());
    EXPECT_FLOAT_EQ(out1[0], 11.0f);  // 1 + 10

    // The second run should reuse the cache and produce an identical result.
    const auto out2 = pipe.run(raw, meta);
    EXPECT_EQ(out1, out2);
}
|
||
|
||
// Checks dirty propagation through a two-node chain (add -> mul):
// marking the downstream node dirty recomputes only it, while marking the
// upstream node dirty must also force the downstream node to recompute.
// Every result's size is asserted before indexing, so a short or empty output
// fails cleanly instead of reading out of bounds.
TEST_F(GprPipelineTest, dirtyChain) {
    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});
    pipe.registerNode({"mul", stubMul, &mulVal});

    const auto out1 = pipe.run(raw, meta);
    ASSERT_EQ(out1.size(), raw.size());
    EXPECT_FLOAT_EQ(out1[0], 22.0f);  // (1 + 10) * 2

    // Change only the mul parameter; add's cached output should be reused.
    mulVal = 3.0f;
    pipe.markDirty("mul");
    const auto out2 = pipe.run(raw, meta);
    ASSERT_EQ(out2.size(), raw.size());
    EXPECT_FLOAT_EQ(out2[0], 33.0f);  // (1 + 10) * 3

    // Change the add parameter; both add and the downstream mul must
    // recompute (a stale mul cache would still yield 33).
    addVal = 5.0f;
    pipe.markDirty("add");
    const auto out3 = pipe.run(raw, meta);
    ASSERT_EQ(out3.size(), raw.size());
    EXPECT_FLOAT_EQ(out3[0], 18.0f);  // (1 + 5) * 3
}
|
||
|
||
// A node disabled via setNodeEnabled(name, false) is skipped: with "add"
// disabled, only "mul" runs, so every output element is raw[i] * mulVal.
TEST_F(GprPipelineTest, disableNode) {
    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});
    pipe.registerNode({"mul", stubMul, &mulVal});

    pipe.setNodeEnabled("add", false);
    const auto out = pipe.run(raw, meta);
    // Assert the size before indexing so a short result cannot read out of bounds.
    ASSERT_EQ(out.size(), raw.size());

    // add disabled, only mul runs: raw * 2
    EXPECT_FLOAT_EQ(out[0], 2.0f);    // 1 * 2
    EXPECT_FLOAT_EQ(out[11], 24.0f);  // 12 * 2
    for (std::size_t i = 0; i < out.size(); ++i) {
        EXPECT_FLOAT_EQ(out[i], raw[i] * mulVal) << "at index " << i;
    }
}
|
||
|
||
// After run(), nodeOutput(name) exposes the named node's intermediate result
// for debugging; for the single "add" node that is raw + addVal.
TEST_F(GprPipelineTest, nodeOutputDebug) {
    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});

    pipe.run(raw, meta);
    const auto* p = pipe.nodeOutput("add");
    ASSERT_NE(p, nullptr);
    // Check the size before indexing into the node's buffer.
    // NOTE(review): assumes nodeOutput() points at a vector-like container
    // exposing size() — confirm against GprProcessingPipeline.
    ASSERT_EQ(p->size(), raw.size());
    EXPECT_FLOAT_EQ((*p)[0], 11.0f);  // 1 + 10
}
|
||
|
||
// run() must reject an input buffer whose length (5 here) does not match the
// fixture's 12-element layout described by `meta`, by throwing
// std::invalid_argument.
TEST_F(GprPipelineTest, sizeMismatchThrows) {
    std::vector<float> tooShort(5);

    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});

    EXPECT_THROW(pipe.run(tooShort, meta), std::invalid_argument);
}
|