123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <Eigen/Dense>
#include <ceres/ceres.h>
#include <matplotlibcpp.h>
- namespace plt = matplotlibcpp;
// Ceres cost functor for fitting the model  y = a * sin(b * x + c) + d
// to one observed (x, y) sample.
//
// The parameter block passed to operator() must hold four values:
//   params[0] = amplitude a
//   params[1] = frequency b
//   params[2] = phase     c
//   params[3] = offset    d
// The single residual is the observed y minus the model's prediction at x.
struct SineResidual {
    SineResidual(double x, double y) : sample_x_(x), sample_y_(y) {}

    template <typename T>
    bool operator()(const T* const params, T* residual) const {
        const T& amplitude = params[0];
        const T& frequency = params[1];
        const T& phase = params[2];
        const T& offset = params[3];

        // Unqualified sin: resolves to the C math sin for double and, via ADL,
        // to the Jet overload when Ceres differentiates with T = ceres::Jet.
        const T predicted = amplitude * sin(frequency * sample_x_ + phase) + offset;
        residual[0] = sample_y_ - predicted;
        return true;
    }

private:
    const double sample_x_;  // abscissa of the observed sample
    const double sample_y_;  // observed ordinate at sample_x_
};
- int main() {
- // 读取数据 实际
- std::ifstream file1("../f_data.csv");
- if (!file1.is_open()) {
- std::cerr << "Failed to open data file" << std::endl;
- return 1;
- }
- std::vector<double> x_data, y_actual, x_ideal,y_ideal;
- std::string line;
-
- // 跳过标题行
- std::getline(file1, line);
-
- while (std::getline(file1, line)) {
- double x, y;
- if (sscanf(line.c_str(), "%lf,%lf", &x, &y) == 2) {
- x_data.push_back(x);
- y_actual.push_back(y);
- }
- }
- file1.close();
- // 读取数据 理想
- std::ifstream file2("../f(1)_data.csv");
- if (!file2.is_open()) {
- std::cerr << "Failed to open data file" << std::endl;
- return 1;
- }
- // 跳过标题行
- std::getline(file2, line);
- while (std::getline(file2, line)) {
- double x, y;
- if (sscanf(line.c_str(), "%lf,%lf", &x, &y) == 2) {
- x_ideal.push_back(x);
- y_ideal.push_back(y);
- }
- }
- file2.close();
- // 初始参数猜测 [a, b, c, d]
- double params[2] = {5.0,1.0 };
-
- // 构建优化问题
- ceres::Problem problem;
- for (size_t i = 0; i < x_data.size(); ++i) {
- problem.AddResidualBlock( //向问题中添加误差项
- //使用自动求导,模板参数:误差类型、输出维度、输入维度
- new ceres::AutoDiffCostFunction<SineResidual, 1, 2>(
- new SineResidual(x_data[i], y_actual[i])),
- nullptr,
- params
- );
- }
- // 配置求解器
- ceres::Solver::Options options;
- //options.linear_solver_type = ceres::DENSE_QR;
- options.linear_solver_type = ceres::SPARSE_NORMAL_CHOLESKY;
- options.minimizer_progress_to_stdout = false;
-
- // 运行求解器
- ceres::Solver::Summary summary; //优化信息
- ceres::Solve(options, &problem, &summary); //开始优化
-
- // 输出结果
- // std::cout << summary.BriefReport() << "\n";
- // std::cout << "Final params: a=" << params[0] << " b=" << params[1]
- // << " c=" << params[2] << " d=" << params[3] << "\n";
- // 生成拟合曲线
- std::vector<double> y_fit;
- for (double x : x_data) {
- y_fit.push_back(params[0] * sin(params[1] * x + params[2]) + params[3]);
- }
- // // 生成理想曲线 (y = sin(x))
- // for (double x : x_data) {
- // y_ideal.push_back(sin(x));
- // }
- // 计算绝对差距和百分比差异
- double total_abs_diff = 0.0;
- double total_percent_diff = 0.0;
- int count = y_fit.size();
- for (size_t i = 0; i < count; ++i) {
- double abs_diff = std::abs(y_fit[i] - y_ideal[i]);
- double percent_diff = (abs_diff / std::abs(y_ideal[i])) * 100;
- // 处理y_ideal接近0时的特殊情况
- if (std::abs(y_ideal[i]) < 1e-6) {
- percent_diff = 0; // 避免除以零
- }
- total_abs_diff += abs_diff;
- total_percent_diff += percent_diff;
- }
- double mean_abs_diff = total_abs_diff / count;
- double mean_percent_diff = total_percent_diff / count;
- // 将评估指标转换为字符串
- std::ostringstream stats_text;
- stats_text << "Mean Absolute Difference: " << std::fixed << std::setprecision(4) << mean_abs_diff << "\n"
- << "Mean Percentage Difference: " << std::fixed << std::setprecision(2) << mean_percent_diff << "%";
- // 可视化
- plt::figure_size(1200, 800);
- plt::scatter(x_data, y_actual, 10, {{"color", "red"}, {"label", "Actual Data"}});
- plt::plot(x_data, y_fit, {{"color", "green"}, {"label", "Fitted Curve"}});
- plt::plot(x_ideal, y_ideal, {{"color", "black"}, {"label", "Ideal Sine"}});
- // 在图像上添加评估指标文本
- // plt::text(0.02, 0.95, stats_text.str(),
- // {{"color", "blue"},
- // {"fontsize", "12"},
- // {"bbox", "dict(facecolor='white', alpha=0.7, edgecolor='none')"}});
- plt::text(0.01, -0.5, stats_text.str());
- plt::xlabel("x");
- plt::ylabel("y");
- plt::title("Sine Curve Fitting using Ceres");
- plt::legend();
- plt::save("sine_fit.png");
- plt::show();
- return 0;
- }
|