geopro/tests/core/gpr_proc/test_pipeline.cpp

101 lines
3.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <cstddef>
#include <stdexcept>
#include <vector>

#include <gtest/gtest.h>

#include "gpr_proc/GprProcessingPipeline.hpp"
#include "model/gpr_proc/GprAlgoParams.hpp"
using namespace geopro::core;
// Stub algorithm: element-wise adder.
// Writes in[k] + offset into out[k] for each of the meta.ntraces * meta.samples
// elements. `params` must point to a float holding the offset.
static void stubAdd(const float* in, float* out, const GprTraceMeta& meta, const void* params) {
    const float* const offsetPtr = static_cast<const float*>(params);
    const float offset = *offsetPtr;
    const int total = meta.ntraces * meta.samples;
    int k = 0;
    while (k < total) {
        out[k] = in[k] + offset;
        ++k;
    }
}
// Stub algorithm: element-wise multiplier.
// Writes in[idx] * scale into out[idx] for each of the
// meta.ntraces * meta.samples elements. `params` must point to a float holding
// the scale factor.
static void stubMul(const float* in, float* out, const GprTraceMeta& meta, const void* params) {
    const float factor = *static_cast<const float*>(params);
    const int count = meta.ntraces * meta.samples;
    for (int idx = 0; idx < count; ++idx) {
        out[idx] = factor * in[idx];
    }
}
// Shared fixture: a 12-element input laid out as 4 groups of 3, plus the
// parameter values the add/mul stub nodes read through their params pointer.
class GprPipelineTest : public ::testing::Test {
protected:
    // Positional aggregate init; field order is defined by GprTraceMeta.
    // NOTE(review): assumes the leading fields describe a 4 x 3 layout so that
    // ntraces * samples == raw.size() == 12 -- confirm against GprTraceMeta.
    GprTraceMeta meta{4, 3, 1.0, 1.0, 0.0, 0.0, true, 0.12};
    std::vector<float> raw{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    // Registered nodes keep pointers to these, so tests may change them
    // between runs (together with markDirty) to exercise recomputation.
    float addVal{10.0f};
    float mulVal{2.0f};
};
// Two chained nodes run in registration order: out = (raw + addVal) * mulVal.
// Every element is checked, not just a few sample indices.
TEST_F(GprPipelineTest, basicRun) {
    GprProcessingPipeline pipe;
    pipe.registerNode({"add", stubAdd, &addVal});
    pipe.registerNode({"mul", stubMul, &mulVal});
    auto out = pipe.run(raw, meta);
    ASSERT_EQ(out.size(), raw.size());

    // Expected value computed with the same float operations the stubs use,
    // so the comparison is exact.
    for (std::size_t i = 0; i < raw.size(); ++i) {
        EXPECT_FLOAT_EQ(out[i], (raw[i] + addVal) * mulVal) << "at index " << i;
    }

    // Literal anchors so a broken fixture cannot make the loop above vacuous.
    EXPECT_FLOAT_EQ(out[0], 22.0f);   // (1 + 10) * 2
    EXPECT_FLOAT_EQ(out[1], 24.0f);   // (2 + 10) * 2
    EXPECT_FLOAT_EQ(out[11], 44.0f);  // (12 + 10) * 2
}
// A second run() with no node marked dirty should be served from the cache:
// same output, and the node's algorithm must not be invoked again.
TEST_F(GprPipelineTest, cachingNoDirty) {
    // Invocation counter for the wrapped add stub. It is a function-local
    // static so the captureless lambda below can use it and still convert to
    // a plain function pointer of exactly stubAdd's type.
    static int addCalls = 0;
    addCalls = 0;
    const decltype(&stubAdd) countingAdd =
        [](const float* in, float* out, const GprTraceMeta& m, const void* params) {
            ++addCalls;
            stubAdd(in, out, m, params);
        };

    GprProcessingPipeline pipe;
    pipe.registerNode({"add", countingAdd, &addVal});

    auto out1 = pipe.run(raw, meta);
    ASSERT_EQ(addCalls, 1);

    auto out2 = pipe.run(raw, meta);
    EXPECT_EQ(out1, out2);
    // Equal results alone would also pass for a pipeline that always
    // recomputes; the call count is what proves the cache was reused.
    EXPECT_EQ(addCalls, 1);
}
// markDirty() on a node must force it (and everything downstream) to
// recompute with the current parameter values.
TEST_F(GprPipelineTest, dirtyChain) {
    GprProcessingPipeline pipeline;
    pipeline.registerNode({"add", stubAdd, &addVal});
    pipeline.registerNode({"mul", stubMul, &mulVal});

    // Initial run: (1 + 10) * 2.
    auto initial = pipeline.run(raw, meta);
    EXPECT_FLOAT_EQ(initial[0], 22.0f);

    // Change only the mul parameter; add's cached output should be reused.
    mulVal = 3.0f;
    pipeline.markDirty("mul");
    auto afterMulChange = pipeline.run(raw, meta);
    EXPECT_FLOAT_EQ(afterMulChange[0], 33.0f);  // (1 + 10) * 3

    // Change the add parameter; both add and the downstream mul must recompute.
    addVal = 5.0f;
    pipeline.markDirty("add");
    auto afterAddChange = pipeline.run(raw, meta);
    EXPECT_FLOAT_EQ(afterAddChange[0], 18.0f);  // (1 + 5) * 3
}
// A node disabled via setNodeEnabled(name, false) is expected to be skipped.
TEST_F(GprPipelineTest, disableNode) {
    GprProcessingPipeline pipeline;
    pipeline.registerNode({"add", stubAdd, &addVal});
    pipeline.registerNode({"mul", stubMul, &mulVal});
    pipeline.setNodeEnabled("add", false);

    // With add disabled only mul runs, so the result is raw * 2.
    auto result = pipeline.run(raw, meta);
    EXPECT_FLOAT_EQ(result[0], 2.0f);    // 1 * 2
    EXPECT_FLOAT_EQ(result[11], 24.0f);  // 12 * 2
}
// After a run, nodeOutput(name) exposes that node's intermediate buffer.
TEST_F(GprPipelineTest, nodeOutputDebug) {
    GprProcessingPipeline pipeline;
    pipeline.registerNode({"add", stubAdd, &addVal});
    pipeline.run(raw, meta);

    const auto* addOutput = pipeline.nodeOutput("add");
    ASSERT_NE(addOutput, nullptr);
    EXPECT_FLOAT_EQ((*addOutput)[0], 11.0f);  // 1 + 10
}
// run() must reject input whose length disagrees with the trace metadata.
TEST_F(GprPipelineTest, sizeMismatchThrows) {
    GprProcessingPipeline pipeline;
    pipeline.registerNode({"add", stubAdd, &addVal});

    // 5 elements does not match the size the fixture's meta describes
    // (the matching fixture input, raw, has 12).
    std::vector<float> wrongSized(5);
    EXPECT_THROW(pipeline.run(wrongSized, meta), std::invalid_argument);
}