#include #include "gpr_proc/GprProcessingPipeline.hpp" #include "model/gpr_proc/GprAlgoParams.hpp" using namespace geopro::core; // 桩算法:加法器(每元素 + offset) static void stubAdd(const float* in, float* out, const GprTraceMeta& meta, const void* params) { const int n = meta.ntraces * meta.samples; const float offset = *static_cast(params); for (int i = 0; i < n; ++i) out[i] = in[i] + offset; } // 桩算法:乘法器(每元素 * scale) static void stubMul(const float* in, float* out, const GprTraceMeta& meta, const void* params) { const int n = meta.ntraces * meta.samples; const float scale = *static_cast(params); for (int i = 0; i < n; ++i) out[i] = in[i] * scale; } class GprPipelineTest : public ::testing::Test { protected: GprTraceMeta meta{4, 3, 1.0, 1.0, 0.0, 0.0, true, 0.12}; std::vector raw = {1,2,3, 4,5,6, 7,8,9, 10,11,12}; float addVal = 10.0f; float mulVal = 2.0f; }; TEST_F(GprPipelineTest, basicRun) { GprProcessingPipeline pipe; pipe.registerNode({"add", stubAdd, &addVal}); pipe.registerNode({"mul", stubMul, &mulVal}); auto out = pipe.run(raw, meta); ASSERT_EQ(out.size(), raw.size()); // (raw + 10) * 2 EXPECT_FLOAT_EQ(out[0], 22.0f); // (1+10)*2 EXPECT_FLOAT_EQ(out[1], 24.0f); // (2+10)*2 EXPECT_FLOAT_EQ(out[11], 44.0f); // (12+10)*2 } TEST_F(GprPipelineTest, cachingNoDirty) { GprProcessingPipeline pipe; pipe.registerNode({"add", stubAdd, &addVal}); auto out1 = pipe.run(raw, meta); auto out2 = pipe.run(raw, meta); // 第二次应复用缓存,结果相同 EXPECT_EQ(out1, out2); } TEST_F(GprPipelineTest, dirtyChain) { GprProcessingPipeline pipe; pipe.registerNode({"add", stubAdd, &addVal}); pipe.registerNode({"mul", stubMul, &mulVal}); auto out1 = pipe.run(raw, meta); EXPECT_FLOAT_EQ(out1[0], 22.0f); // 只改 mul 参数,add 应复用缓存 mulVal = 3.0f; pipe.markDirty("mul"); auto out2 = pipe.run(raw, meta); EXPECT_FLOAT_EQ(out2[0], 33.0f); // (1+10)*3 // 改 add 参数,add 和 mul 都应重算 addVal = 5.0f; pipe.markDirty("add"); auto out3 = pipe.run(raw, meta); EXPECT_FLOAT_EQ(out3[0], 18.0f); // (1+5)*3 } TEST_F(GprPipelineTest, disableNode) { GprProcessingPipeline pipe; pipe.registerNode({"add", stubAdd, &addVal}); pipe.registerNode({"mul", stubMul, &mulVal}); pipe.setNodeEnabled("add", false); auto out = pipe.run(raw, meta); // add 禁用,只跑 mul:raw * 2 EXPECT_FLOAT_EQ(out[0], 2.0f); EXPECT_FLOAT_EQ(out[11], 24.0f); } TEST_F(GprPipelineTest, nodeOutputDebug) { GprProcessingPipeline pipe; pipe.registerNode({"add", stubAdd, &addVal}); pipe.run(raw, meta); const auto* p = pipe.nodeOutput("add"); ASSERT_NE(p, nullptr); EXPECT_FLOAT_EQ((*p)[0], 11.0f); } TEST_F(GprPipelineTest, sizeMismatchThrows) { GprProcessingPipeline pipe; pipe.registerNode({"add", stubAdd, &addVal}); std::vector bad(5); EXPECT_THROW(pipe.run(bad, meta), std::invalid_argument); }