// main.cpp (4.7 KB)
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <ceres/ceres.h>
#include <Eigen/Dense>
#include <matplotlibcpp.h>

namespace plt = matplotlibcpp;
  9. // 定义残差结构体
  10. struct SineResidual {
  11. SineResidual(double x, double y) : x_(x), y_(y) {}
  12. template <typename T>
  13. bool operator()(const T* const params, T* residual) const {
  14. // params[0]: amplitude (a)
  15. // params[1]: frequency (b)
  16. // params[2]: phase (c)
  17. // params[3]: offset (d)
  18. residual[0] = y_ - (params[0] * sin(params[1] * x_ + params[2]) + params[3]);
  19. return true;
  20. }
  21. private:
  22. const double x_;
  23. const double y_;
  24. };
  25. int main() {
  26. // 读取数据 实际
  27. std::ifstream file1("../f_data.csv");
  28. if (!file1.is_open()) {
  29. std::cerr << "Failed to open data file" << std::endl;
  30. return 1;
  31. }
  32. std::vector<double> x_data, y_actual, x_ideal,y_ideal;
  33. std::string line;
  34. // 跳过标题行
  35. std::getline(file1, line);
  36. while (std::getline(file1, line)) {
  37. double x, y;
  38. if (sscanf(line.c_str(), "%lf,%lf", &x, &y) == 2) {
  39. x_data.push_back(x);
  40. y_actual.push_back(y);
  41. }
  42. }
  43. file1.close();
  44. // 读取数据 理想
  45. std::ifstream file2("../f(1)_data.csv");
  46. if (!file2.is_open()) {
  47. std::cerr << "Failed to open data file" << std::endl;
  48. return 1;
  49. }
  50. // 跳过标题行
  51. std::getline(file2, line);
  52. while (std::getline(file2, line)) {
  53. double x, y;
  54. if (sscanf(line.c_str(), "%lf,%lf", &x, &y) == 2) {
  55. x_ideal.push_back(x);
  56. y_ideal.push_back(y);
  57. }
  58. }
  59. file2.close();
  60. // 初始参数猜测 [a, b, c, d]
  61. double params[2] = {5.0,1.0 };
  62. // 构建优化问题
  63. ceres::Problem problem;
  64. for (size_t i = 0; i < x_data.size(); ++i) {
  65. problem.AddResidualBlock( //向问题中添加误差项
  66. //使用自动求导,模板参数:误差类型、输出维度、输入维度
  67. new ceres::AutoDiffCostFunction<SineResidual, 1, 2>(
  68. new SineResidual(x_data[i], y_actual[i])),
  69. nullptr,
  70. params
  71. );
  72. }
  73. // 配置求解器
  74. ceres::Solver::Options options;
  75. //options.linear_solver_type = ceres::DENSE_QR;
  76. options.linear_solver_type = ceres::SPARSE_NORMAL_CHOLESKY;
  77. options.minimizer_progress_to_stdout = false;
  78. // 运行求解器
  79. ceres::Solver::Summary summary; //优化信息
  80. ceres::Solve(options, &problem, &summary); //开始优化
  81. // 输出结果
  82. // std::cout << summary.BriefReport() << "\n";
  83. // std::cout << "Final params: a=" << params[0] << " b=" << params[1]
  84. // << " c=" << params[2] << " d=" << params[3] << "\n";
  85. // 生成拟合曲线
  86. std::vector<double> y_fit;
  87. for (double x : x_data) {
  88. y_fit.push_back(params[0] * sin(params[1] * x + params[2]) + params[3]);
  89. }
  90. // // 生成理想曲线 (y = sin(x))
  91. // for (double x : x_data) {
  92. // y_ideal.push_back(sin(x));
  93. // }
  94. // 计算绝对差距和百分比差异
  95. double total_abs_diff = 0.0;
  96. double total_percent_diff = 0.0;
  97. int count = y_fit.size();
  98. for (size_t i = 0; i < count; ++i) {
  99. double abs_diff = std::abs(y_fit[i] - y_ideal[i]);
  100. double percent_diff = (abs_diff / std::abs(y_ideal[i])) * 100;
  101. // 处理y_ideal接近0时的特殊情况
  102. if (std::abs(y_ideal[i]) < 1e-6) {
  103. percent_diff = 0; // 避免除以零
  104. }
  105. total_abs_diff += abs_diff;
  106. total_percent_diff += percent_diff;
  107. }
  108. double mean_abs_diff = total_abs_diff / count;
  109. double mean_percent_diff = total_percent_diff / count;
  110. // 将评估指标转换为字符串
  111. std::ostringstream stats_text;
  112. stats_text << "Mean Absolute Difference: " << std::fixed << std::setprecision(4) << mean_abs_diff << "\n"
  113. << "Mean Percentage Difference: " << std::fixed << std::setprecision(2) << mean_percent_diff << "%";
  114. // 可视化
  115. plt::figure_size(1200, 800);
  116. plt::scatter(x_data, y_actual, 10, {{"color", "red"}, {"label", "Actual Data"}});
  117. plt::plot(x_data, y_fit, {{"color", "green"}, {"label", "Fitted Curve"}});
  118. plt::plot(x_ideal, y_ideal, {{"color", "black"}, {"label", "Ideal Sine"}});
  119. // 在图像上添加评估指标文本
  120. // plt::text(0.02, 0.95, stats_text.str(),
  121. // {{"color", "blue"},
  122. // {"fontsize", "12"},
  123. // {"bbox", "dict(facecolor='white', alpha=0.7, edgecolor='none')"}});
  124. plt::text(0.01, -0.5, stats_text.str());
  125. plt::xlabel("x");
  126. plt::ylabel("y");
  127. plt::title("Sine Curve Fitting using Ceres");
  128. plt::legend();
  129. plt::save("sine_fit.png");
  130. plt::show();
  131. return 0;
  132. }