merge plot into main

This commit is contained in:
陈浩南 2023-10-05 16:20:53 +08:00
parent ed60d9ab05
commit 29cdc32ec0
6 changed files with 264 additions and 179 deletions

View File

@ -22,16 +22,11 @@ find_package(TBB REQUIRED)
find_package(Matplot++ REQUIRED)
find_path(ZPP_BITS_INCLUDE_DIR zpp_bits.h REQUIRED)
add_executable(ufo src/unfold.cpp src/main.cpp)
add_executable(ufo src/unfold.cpp src/plot.cpp src/main.cpp)
target_include_directories(ufo PRIVATE ${PROJECT_SOURCE_DIR}/include ${ZPP_BITS_INCLUDE_DIR})
target_link_libraries(ufo PRIVATE
yaml-cpp Eigen3::Eigen fmt::fmt concurrencpp::concurrencpp HighFive_HighFive TBB::tbb Matplot++::matplot)
add_executable(plot src/plot.cpp)
target_include_directories(plot PRIVATE ${PROJECT_SOURCE_DIR}/include ${ZPP_BITS_INCLUDE_DIR})
target_link_libraries(plot PRIVATE
yaml-cpp Eigen3::Eigen fmt::fmt concurrencpp::concurrencpp HighFive_HighFive TBB::tbb Matplot++::matplot)
get_property(ImportedTargets DIRECTORY "${CMAKE_SOURCE_DIR}" PROPERTY IMPORTED_TARGETS)
message("Imported targets: ${ImportedTargets}")
message("List of compile features: ${CMAKE_CXX_COMPILE_FEATURES}")

60
include/ufo/plot.hpp Normal file
View File

@ -0,0 +1,60 @@
# pragma once
# include <ufo/unfold.hpp>
namespace ufo
{
class PlotSolver : public Solver
{
public:
struct InputType
{
Eigen::Matrix3d PrimativeCell;
struct FigureConfigType
{
std::vector<std::vector<Eigen::Vector3d>> Qpoints;
std::pair<unsigned, unsigned> Resolution;
std::pair<double, double> Range;
std::string Filename;
};
std::vector<FigureConfigType> Figures;
struct SourceType : public UnfoldSolver::OutputType
{
SourceType(std::string filename);
SourceType() = default;
};
std::string SourceFilename;
SourceType Source;
InputType(std::string config_file);
};
protected:
InputType Input_;
public:
PlotSolver(std::string config_file);
PlotSolver& operator()() override;
// 根据 q 点路径, 搜索要使用的 q 点
static std::vector<std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>> search_qpoints
(
const std::pair<Eigen::Vector3d, Eigen::Vector3d>& path,
const decltype(InputType::SourceType::QPointData)& available_qpoints,
double threshold, bool exclude_endpoint = false
);
// 根据搜索到的 q 点, 计算每个点的数值
static std::vector<std::vector<double>> calculate_values
(
const std::vector<std::pair<Eigen::Vector3d, Eigen::Vector3d>>& path,
const std::vector<std::vector<std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>>>& qpoints,
const decltype(InputType::FigureConfigType::Resolution)& resolution,
const decltype(InputType::FigureConfigType::Range)& range
);
// 根据数值, 画图
static void plot
(
const std::vector<std::vector<double>>& values,
const decltype(InputType::FigureConfigType::Filename)& filename
);
};
}

View File

@ -50,7 +50,7 @@ namespace ufo
{
public:
virtual Solver& operator()() = 0;
~Solver() = default;
virtual ~Solver() = default;
inline static concurrencpp::generator<std::pair<Eigen::Vector<unsigned, 3>, unsigned>>
triplet_sequence(Eigen::Vector<unsigned, 3> range)

View File

@ -1,9 +1,13 @@
# include <ufo/unfold.hpp>
# include <ufo/plot.hpp>
int main(int argc, const char** argv)
{
if (argc != 2)
throw std::runtime_error(fmt::format("Usage: {} config.yaml", argv[0]));
ufo::UnfoldSolver{argv[1]}();
throw std::runtime_error(fmt::format("Usage: {} task config.yaml", argv[0]));
if (argv[1] == std::string("unfold"))
ufo::UnfoldSolver{argv[1]}();
else if (argv[1] == std::string("plot"))
ufo::PlotSolver{argv[1]}();
else
throw std::runtime_error(fmt::format("Unknown task: {}", argv[1]));
}

View File

@ -1,171 +1,215 @@
# include <ufo/ufo.impl.hpp>
# include <ufo/plot.hpp>
// 要被用来画图的路径
std::vector<Eigen::Vector3d> Qpoints =
namespace ufo
{
{ 0, 0, 0 },
{ 0.5, 0, 0 },
{ 1. / 3, 1. / 3, 0 },
{ 0, 0, 0 },
{ 0, 0, 0.5 },
{ 0.5, 0, 0.5 },
{ 1. / 3, 1. / 3, 0.5 },
{ 0, 0, 0.5 }
};
double Threshold = 0.001;
struct Point
{
Eigen::Vector3d QPoint;
Eigen::VectorXd Frequency, Weight;
double Distance;
};
int main(int argc, char** argv)
{
if (argc != 2)
throw std::runtime_error("Usage: plot output.zpp");
Output output(argv[1]);
std::vector<Point> Points;
double total_distance = 0;
// 对于每一条路径进行搜索
for (unsigned i = 0; i < Qpoints.size() - 1; i++)
PlotSolver::InputType::SourceType::SourceType(std::string filename)
{
std::vector<Point> point_of_this_path;
// 对于 output 中的每一个点, 检查这个点是否在路径上. 如果在, 把它加入到 point_of_this_path 中
for (auto& qpoint : output.QPointData)
auto input = std::ifstream(filename, std::ios::binary | std::ios::in);
input.exceptions(std::ios::badbit | std::ios::failbit);
static_assert(sizeof(std::byte) == sizeof(char));
std::vector<std::byte> data;
{
std::vector<char> string(std::istreambuf_iterator<char>(input), {});
data.assign
(
reinterpret_cast<std::byte*>(string.data()),
reinterpret_cast<std::byte*>(string.data() + string.size())
);
}
auto in = zpp::bits::in(data);
UnfoldSolver::OutputType output;
in(output).or_throw();
static_cast<UnfoldSolver::OutputType&>(*this) = std::move(output);
}
PlotSolver::InputType::InputType(std::string config_file)
{
auto input = YAML::LoadFile(config_file);
for (unsigned i = 0; i < 3; i++)
for (unsigned j = 0; j < 3; j++)
PrimativeCell(i, j) = input["PrimativeCell"][i][j].as<double>();
for (auto& figure : input["Figures"].as<std::vector<YAML::Node>>())
{
Figures.emplace_back();
auto qpoints = figure["Qpoints"]
.as<std::vector<std::vector<std::vector<double>>>>();
for (auto& line : qpoints)
{
Figures.back().Qpoints.emplace_back();
for (auto& point : line)
Figures.back().Qpoints.back().emplace_back(point.at(0), point.at(1), point.at(2));
if (Figures.back().Qpoints.back().size() < 2)
throw std::runtime_error("Not enough points in a line");
}
if (Figures.back().Qpoints.size() < 1)
throw std::runtime_error("Not enough lines in a figure");
Figures.back().Resolution = figure["Resolution"].as<std::pair<unsigned, unsigned>>();
Figures.back().Range = figure["Range"].as<std::pair<double, double>>();
Figures.back().Filename = figure["Filename"].as<std::string>();
}
SourceFilename = input["SourceFilename"].as<std::string>();
Source = SourceType(SourceFilename);
}
PlotSolver::PlotSolver(std::string config_file) : Input_(config_file) {}
PlotSolver& PlotSolver::operator()()
{
for (auto& figure : Input_.Figures)
{
// 外层表示不同的线段的端点,内层表示这个线段上的 q 点
std::vector<std::vector<std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>>> qpoints;
std::vector<std::pair<Eigen::Vector3d, Eigen::Vector3d>> lines;
for (auto& path : figure.Qpoints)
for (unsigned i = 0; i < path.size() - 1; i++)
{
lines.emplace_back(path[i], path[i + 1]);
qpoints.push_back(search_qpoints
(
lines.back(), Input_.Source.QPointData,
0.001, i != path.size() - 2
));
}
auto values = calculate_values(lines, qpoints, figure.Resolution, figure.Range);
plot(values, figure.Filename);
}
return *this;
}
std::vector<std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>> PlotSolver::search_qpoints
(
const std::pair<Eigen::Vector3d, Eigen::Vector3d>& path,
const decltype(InputType::SourceType::QPointData)& available_qpoints,
double threshold, bool exclude_endpoint
)
{
std::multimap<double, std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>> selected_qpoints;
// 对于 output 中的每一个点, 检查这个点是否在路径上. 如果在, 把它加入到 selected_qpoints 中
for (auto& qpoint : available_qpoints)
{
// 计算三点围成的三角形的面积的两倍
auto area = (Qpoints[i + 1] - Qpoints[i]).cross(qpoint.QPoint - Qpoints[i]).norm();
auto area = (path.second - path.first).cross(qpoint.QPoint - path.first).norm();
// 计算这个点到前两个点所在直线的距离
auto distance = area / (Qpoints[i + 1] - Qpoints[i]).norm();
auto distance = area / (path.second - path.first).norm();
// 如果这个点到前两个点所在直线的距离小于阈值, 则认为这个点在路径上
if (distance < Threshold)
if (distance < threshold)
{
// 计算这个点到前两个点的距离, 两个距离都应该小于两点之间的距离
auto distance1 = (qpoint.QPoint - Qpoints[i]).norm();
auto distance2 = (qpoint.QPoint - Qpoints[i + 1]).norm();
auto distance3 = (Qpoints[i + 1] - Qpoints[i]).norm();
if (distance1 < distance3 + Threshold && distance2 < distance3 + Threshold)
// 如果这个点在终点处, 且这条路径不是最后一条, 则不加入
if (distance2 > Threshold || i == Qpoints.size() - 2)
{
auto& _ = point_of_this_path.emplace_back();
_.QPoint = qpoint.QPoint;
_.Distance = distance1;
_.Frequency.resize(qpoint.ModeData.size());
_.Weight.resize(qpoint.ModeData.size());
for (unsigned j = 0; j < qpoint.ModeData.size(); j++)
{
_.Frequency(j) = qpoint.ModeData[j].Frequency;
_.Weight(j) = qpoint.ModeData[j].Weight;
}
}
auto distance1 = (qpoint.QPoint - path.first).norm();
auto distance2 = (qpoint.QPoint - path.second).norm();
auto distance3 = (path.second - path.first).norm();
if (distance1 < distance3 + threshold && distance2 < distance3 + threshold)
// 如果这个点不在终点处, 或者不排除终点, 则加入
if (distance2 > threshold || !exclude_endpoint)
selected_qpoints.emplace(distance1, std::ref(qpoint));
}
}
// 对筛选结果排序
std::sort(point_of_this_path.begin(), point_of_this_path.end(),
[](const Point& a, const Point& b) { return a.Distance < b.Distance; });
// 去除非常接近的点
for (unsigned j = 1; j < point_of_this_path.size(); j++)
while
(
j < point_of_this_path.size()
&& point_of_this_path[j].Distance - point_of_this_path[j - 1].Distance < Threshold
)
point_of_this_path.erase(point_of_this_path.begin() + j);
// 将结果加入
for (auto& point : point_of_this_path)
Points.emplace_back(point.QPoint, point.Frequency, point.Weight, point.Distance + total_distance);
total_distance += (Qpoints[i + 1] - Qpoints[i]).norm();
for (auto it = selected_qpoints.begin(); it != selected_qpoints.end();)
{
auto next = std::next(it);
if (next == selected_qpoints.end())
break;
else if (next->first - it->first < threshold)
selected_qpoints.erase(next);
else
it = next;
}
if (selected_qpoints.empty())
throw std::runtime_error("No q points found");
std::vector<std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>> result;
for (auto& qpoint : selected_qpoints)
result.push_back(qpoint.second);
return result;
}
// 打印结果看一下
for (auto& point : Points)
std::cout << point.Distance << " " << point.QPoint.transpose() << std::endl;
// 对结果插值
std::vector<Point> interpolated_points;
for (unsigned i = 0; i < 1024; i++)
std::vector<std::vector<double>> PlotSolver::calculate_values
(
const std::vector<std::pair<Eigen::Vector3d, Eigen::Vector3d>>& path,
const std::vector<std::vector<std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>>>& qpoints,
const decltype(InputType::FigureConfigType::Resolution)& resolution,
const decltype(InputType::FigureConfigType::Range)& range
)
{
auto current_distance = i * total_distance / 1024;
auto& _ = interpolated_points.emplace_back();
_.Distance = current_distance;
auto it = std::lower_bound(Points.begin(), Points.end(), current_distance,
[](const Point& a, double b) { return a.Distance < b; });
// 如果是开头或者结尾, 直接赋值, 否则插值
if (it == Points.begin())
// 整理输入
std::map<double, std::reference_wrapper<const UnfoldSolver::OutputType::QPointDataType>> qpoints_with_distance;
double total_distance = 0;
for (unsigned i = 0; i < path.size(); i++)
{
_.Frequency = Points.front().Frequency;
_.Weight = Points.front().Weight;
for (auto& _ : qpoints[i])
qpoints_with_distance.emplace(total_distance + (_.get().QPoint - path[i].first).norm(), _);
total_distance += (path[i].second - path[i].first).norm();
}
else if (it == Points.end() - 1)
{
_.Frequency = Points.back().Frequency;
_.Weight = Points.back().Weight;
}
else
{
_.Frequency = (it->Frequency * (it->Distance - current_distance)
+ (it - 1)->Frequency * (current_distance - (it - 1)->Distance)) / (it->Distance - (it - 1)->Distance);
_.Weight = (it->Weight * (it->Distance - current_distance)
+ (it - 1)->Weight * (current_distance - (it - 1)->Distance)) / (it->Distance - (it - 1)->Distance);
}
}
// 将结果对应到像素上的值
std::vector<std::vector<double>> weight(1024, std::vector<double>(400, 0));
for (auto& point : interpolated_points)
// 插值
std::vector<std::vector<double>> values;
auto blend = []
(
const UnfoldSolver::OutputType::QPointDataType& a,
const UnfoldSolver::OutputType::QPointDataType& b,
double ratio, unsigned resolution, std::pair<double, double> range
) -> std::vector<double>
{
// 计算插值结果
std::vector<double> frequency, weight;
for (unsigned i = 0; i < a.ModeData.size(); i++)
{
frequency.push_back(a.ModeData[i].Frequency * ratio + b.ModeData[i].Frequency * (1 - ratio));
weight.push_back(a.ModeData[i].Weight * ratio + b.ModeData[i].Weight * (1 - ratio));
}
std::vector<double> result(resolution);
for (unsigned i = 0; i < frequency.size(); i++)
{
int index = (frequency[i] - range.first) / (range.second - range.first) * resolution;
if (index >= 0 && index < static_cast<int>(resolution))
result[index] += weight[i];
}
return result;
};
for (unsigned i = 0; i < resolution.first; i++)
{
auto current_distance = total_distance * i / resolution.first;
auto it = qpoints_with_distance.lower_bound(current_distance);
if (it == qpoints_with_distance.begin())
values.push_back(blend(it->second.get(), it->second.get(), 1, resolution.second, range));
else if (it == qpoints_with_distance.end())
values.push_back(blend(std::prev(it)->second.get(), std::prev(it)->second.get(), 1, resolution.second,
range));
else
values.push_back(blend
(
std::prev(it)->second.get(), it->second.get(),
(current_distance - std::prev(it)->first) / (it->first - std::prev(it)->first),
resolution.second, range)
);
}
return values;
}
void PlotSolver::plot
(
const std::vector<std::vector<double>>& values,
const decltype(InputType::FigureConfigType::Filename)& filename
)
{
int x = point.Distance / total_distance * 1024;
if (x < 0)
x = 0;
else if (x >= 1024)
x = 1023;
for (unsigned i = 0; i < point.Frequency.size(); i++)
{
auto y = (point.Frequency(i) + 5) * 10;
if (y < 0)
y = 0;
else if (y >= 400)
y = 399;
weight[x][y] += point.Weight(i);
}
std::vector<std::vector<double>>
r(values[0].size(), std::vector<double>(values.size(), 0)),
g(values[0].size(), std::vector<double>(values.size(), 0)),
b(values[0].size(), std::vector<double>(values.size(), 0));
for (unsigned i = 0; i < values[0].size(); i++)
for (unsigned j = 0; j < values.size(); j++)
{
r[i][j] = 255;
g[i][j] = 255 - values[j][i] * 2 * 255;
if (g[i][j] < 0)
g[i][j] = 0;
b[i][j] = 255 - values[j][i] * 2 * 255;
if (b[i][j] < 0)
b[i][j] = 0;
}
auto f = matplot::figure(true);
auto ax = f->current_axes();
ax->image(std::tie(r, g, b));
ax->y_axis().reverse(false);
f->save(filename);
}
std::ofstream fout("weight.txt");
for (unsigned i = 0; i < 400; i++)
fout << fmt::format(" {:.6f}", i * 0.1 - 5);
fout << std::endl;
for (unsigned i = 0; i < 1024; i++)
{
fout << fmt::format("{:.6f} ", total_distance / 1024 * i);
for (unsigned j = 0; j < 400; j++)
fout << fmt::format("{:.6f} ", weight[i][j]);
fout << std::endl;
}
std::vector<std::vector<double>>
r(400, std::vector<double>(1024, 0)),
g(400, std::vector<double>(1024, 0)),
b(400, std::vector<double>(1024, 0));
for (unsigned i = 0; i < 400; i++)
for (unsigned j = 0; j < 1024; j++)
{
r[i][j] = 255;
g[i][j] = 255 - weight[j][i] * 2 * 255;
if (g[i][j] < 0)
g[i][j] = 0;
b[i][j] = 255 - weight[j][i] * 2 * 255;
if (b[i][j] < 0)
b[i][j] = 0;
}
auto f = matplot::figure(true);
auto ax = f->current_axes();
ax->image(std::tie(r, g, b));
ax->y_axis().reverse(false);
f->show();
}

View File

@ -478,21 +478,3 @@ namespace ufo
return output;
}
}
// inline Output::Output(std::string filename)
// {
// auto input = std::ifstream(filename, std::ios::binary | std::ios::in);
// input.exceptions(std::ios::badbit | std::ios::failbit);
// std::vector<std::byte> data;
// {
// std::vector<char> string(std::istreambuf_iterator<char>(input), {});
// data.assign
// (
// reinterpret_cast<std::byte*>(string.data()),
// reinterpret_cast<std::byte*>(string.data() + string.size())
// );
// }
// auto in = zpp::bits::in(data);
// in(*this).or_throw();
// }