修复bug, 增加功能

This commit is contained in:
陈浩南 2023-09-27 19:24:41 +08:00
parent e042ae4c15
commit d157a45201
6 changed files with 121 additions and 72 deletions

3
.gitignore vendored
View File

@ -1,9 +1,10 @@
.vscode
.direnv
test/*
main
plot
out.yaml
.ccls-cache
.cache
build/
test/*
!test/14.2.2.yaml

View File

@ -4,12 +4,12 @@
outputs = inputs:
let
pkgs = inputs.nixpkgs.legacyPackages.x86_64-linux;
# pkgs = import inputs.nixpkgs
# {
# localSystem = { system = "x86_64-linux"; gcc = { arch = "alderlake"; tune = "alderlake"; }; };
# config.allowUnfree = true;
# };
# pkgs = inputs.nixpkgs.legacyPackages.x86_64-linux;
pkgs = import inputs.nixpkgs
{
localSystem = { system = "x86_64-linux"; gcc = { arch = "alderlake"; tune = "alderlake"; }; };
config.allowUnfree = true;
};
localPackages = import "${inputs.nixos}/local/pkgs" { inherit pkgs; inherit (inputs.nixpkgs) lib; };
in
{

View File

@ -7,10 +7,10 @@ inline Input::Input(std::string filename)
{
// read main input file
{
auto node = YAML::LoadFile(filename);
for (unsigned i = 0; i < 3; i++)
for (unsigned j = 0; j < 3; j++)
PrimativeCell(i, j) = node["PrimativeCell"][i][j].as<double>();
auto node = YAML::LoadFile(filename);
for (unsigned i = 0; i < 3; i++)
for (unsigned j = 0; j < 3; j++)
PrimativeCell(i, j) = node["PrimativeCell"][i][j].as<double>();
for (unsigned i = 0; i < 3; i++)
SuperCellMultiplier(i) = node["SuperCellMultiplier"][i].as<int>();
@ -26,28 +26,61 @@ inline Input::Input(std::string filename)
for (unsigned i = 0; i < 3; i++)
PrimativeCellBasisNumber(i) = node["PrimativeCellBasisNumber"][i].as<int>();
AtomPositionInputFile.FileName = node["AtomPositionInputFile"]["FileName"].as<std::string>();
AtomPositionInputFile.Format = node["AtomPositionInputFile"]["Format"].as<std::string>();
auto read_file_config = [filename](YAML::Node source, InputOutputFile_& config)
{
if (auto _ = source["SameAsConfigFile"])
{
auto __ = _.as<bool>();
config.ExtraParameters["SameAsConfigFile"] = __;
if (__)
{
config.FileName = filename;
config.Format = "yaml";
return;
}
}
config.FileName = source["FileName"].as<std::string>();
config.Format = source["Format"].as<std::string>();
if (auto _ = source["RelativeToConfigFile"])
{
auto __ = _.as<bool>();
config.ExtraParameters["RelativeToConfigFile"] = __;
if (__)
config.FileName = std::filesystem::path(filename).parent_path() / config.FileName;
}
};
read_file_config(node["AtomPositionInputFile"], AtomPositionInputFile);
if (!std::set<std::string>{"yaml"}.contains(AtomPositionInputFile.Format))
throw std::runtime_error(fmt::format
("Unknown format: {}, should be \"yaml\".", AtomPositionInputFile.Format));
QPointDataInputFile.FileName = node["QPointDataInputFile"]["FileName"].as<std::string>();
QPointDataInputFile.Format = node["QPointDataInputFile"]["Format"].as<std::string>();
("Unknown AtomPositionInputFile.Format: {}, should be \"yaml\".", AtomPositionInputFile.Format));
read_file_config(node["QPointDataInputFile"], QPointDataInputFile);
if (!std::set<std::string>{"yaml", "hdf5"}.contains(QPointDataInputFile.Format))
throw std::runtime_error(fmt::format
("Unknown format: {}, should be \"yaml\" or \"hdf5\".", QPointDataInputFile.Format));
("Unknown QPointDataInputFile.Format: {}, should be \"yaml\" or \"hdf5\".", QPointDataInputFile.Format));
if (auto value = node["QPointDataOutputFile"])
{
QPointDataOutputFile.resize(value.size());
for (unsigned i = 0; i < value.size(); i++)
{
auto& _ = QPointDataOutputFile.emplace_back();
_.FileName = value[i]["FileName"].as<std::string>();
_.Format = value[i]["Format"].as<std::string>();
if (!std::set<std::string>{"yaml", "yaml-human-readable", "zpp"}.contains(_.Format))
read_file_config(value[i], QPointDataOutputFile[i]);
if
(
QPointDataOutputFile[i].ExtraParameters.contains("SameAsConfigFile")
&& std::any_cast<bool>(QPointDataOutputFile[i].ExtraParameters["SameAsConfigFile"])
)
throw std::runtime_error("QPointDataOutputFile.SameAsConfigFile should not be set.");
if
(
!std::set<std::string>{"yaml", "yaml-human-readable", "zpp"}
.contains(QPointDataOutputFile[i].Format)
)
throw std::runtime_error(fmt::format
("Unknown format: {}, should be \"yaml\", \"yaml-human-readable\" or \"zpp\".", _.Format));
(
"Unknown QPointDataOutputFile[{}].Format: {}, should be \"yaml\", \"yaml-human-readable\" or \"zpp\".",
i, QPointDataOutputFile[i].Format
));
}
}
}
if (AtomPositionInputFile.Format == "yaml")

View File

@ -1,32 +1,5 @@
# include <ufo/ufo.impl.hpp>
class FloatVector
{
protected:
std::vector<double> Data_;
double LowerBound_, UpperBound_, Step_;
public:
FloatVector(double LowerBound, double UpperBound, double Step)
: LowerBound_(LowerBound), UpperBound_(UpperBound), Step_(Step), Data_((UpperBound - LowerBound) / Step + 1) {}
FloatVector(const FloatVector&) = default;
FloatVector(FloatVector&&) = default;
FloatVector& operator=(const FloatVector&) = default;
FloatVector& operator=(FloatVector&&) = default;
double& operator[](double i) { return Data_[static_cast<int>((i - LowerBound_) / Step_)]; }
double operator[](double i) const { return Data_[static_cast<int>((i - LowerBound_) / Step_)]; }
double lower_bound() const { return LowerBound_; }
double upper_bound() const { return UpperBound_; }
double step() const { return Step_; }
int size() const { return Data_.size(); }
std::map<double, double> to_map() const
{
std::map<double, double> result;
for (int i = 0; i < Data_.size(); i++)
result[LowerBound_ + i * Step_] = Data_[i];
return result;
}
};
// 要被用来画图的路径
std::vector<Eigen::Vector3d> Qpoints =
{
@ -41,6 +14,13 @@ std::vector<Eigen::Vector3d> Qpoints =
};
double Threshold = 0.001;
struct Point
{
Eigen::Vector3d QPoint;
Eigen::VectorXd Frequency, Weight;
double Distance;
};
int main(int argc, char** argv)
{
if (argc != 2)
@ -48,17 +28,9 @@ int main(int argc, char** argv)
Output output(argv[1]);
//
struct Point
{
Eigen::Vector3d QPoint;
FloatVector Weight;
double Distance;
};
std::vector<Point> Points;
double current_distance = 0;
double total_distance = 0;
// 对于每一条路径进行搜索
for (unsigned i = 0; i < Qpoints.size() - 1; i++)
{
@ -81,10 +53,16 @@ int main(int argc, char** argv)
// 如果这个点在终点处, 且这条路径不是最后一条, 则不加入
if (distance2 > Threshold || i == Qpoints.size() - 2)
{
auto& _ = point_of_this_path.emplace_back
(qpoint.QPoint, FloatVector{-5, 35, 0.1}, distance1);
for (auto& mode : qpoint.ModeData)
_.Weight[mode.Frequency] += mode.Weight;
auto& _ = point_of_this_path.emplace_back();
_.QPoint = qpoint.QPoint;
_.Distance = distance1;
_.Frequency.resize(qpoint.ModeData.size());
_.Weight.resize(qpoint.ModeData.size());
for (unsigned j = 0; j < qpoint.ModeData.size(); j++)
{
_.Frequency(j) = qpoint.ModeData[j].Frequency;
_.Weight(j) = qpoint.ModeData[j].Weight;
}
}
}
}
@ -101,7 +79,36 @@ int main(int argc, char** argv)
point_of_this_path.erase(point_of_this_path.begin() + j);
// 将结果加入
for (auto& point : point_of_this_path)
Points.emplace_back(point.QPoint, point.Weight, point.Distance + current_distance);
current_distance += (Qpoints[i + 1] - Qpoints[i]).norm();
Points.emplace_back(point.QPoint, point.Frequency, point.Weight, point.Distance + total_distance);
total_distance += (Qpoints[i + 1] - Qpoints[i]).norm();
}
// 对结果插值
std::vector<Point> interpolated_points;
for (unsigned i = 0; i < 1024; i++)
{
auto current_distance = i * total_distance / 1024;
auto& _ = interpolated_points.emplace_back();
_.Distance = current_distance;
// 如果是开头或者结尾, 直接赋值, 否则插值
if (current_distance < Points.front().Distance)
{
_.Frequency = Points.front().Frequency;
_.Weight = Points.front().Weight;
}
else if (current_distance > Points.back().Distance)
{
_.Frequency = Points.back().Frequency;
_.Weight = Points.back().Weight;
}
else
{
auto it = std::lower_bound(Points.begin(), Points.end(), current_distance,
[](const Point& a, double b) { return a.Distance < b; });
_.Frequency = (it->Frequency * (it->Distance - current_distance)
+ (it - 1)->Frequency * (current_distance - (it - 1)->Distance)) / (it->Distance - (it - 1)->Distance);
_.Weight = (it->Weight * (it->Distance - current_distance)
+ (it - 1)->Weight * (current_distance - (it - 1)->Distance)) / (it->Distance - (it - 1)->Distance);
}
}
}

View File

@ -90,7 +90,7 @@ int main(int argc, const char** argv)
{
// 当 SuperCellDeformation 不是单位矩阵时, input.QPointData[i_of_qpoint].QPoint 不一定在 reciprocal_primative_cell 中
// 需要首先将 q 点平移数个周期, 进入不包含 SuperCellDeformation 的超胞的倒格子中
auto qpoint_by_reciprocal_super_cell_in_modified_reciprocal_super_cell
auto qpoint_by_reciprocal_modified_super_cell_in_modified_reciprocal_super_cell
= !input.SuperCellDeformation ? input.QPointData[i_of_qpoint].QPoint : [&]
{
auto current_qpoint = input.QPointData[i_of_qpoint].QPoint;
@ -139,16 +139,12 @@ int main(int argc, const char** argv)
}
current_qpoint = min_score_qpoint;
}
return current_qpoint;
return input.SuperCellDeformation->inverse() * current_qpoint;
}();
for (auto [xyz_of_diff_of_sub_qpoint_by_reciprocal_modified_super_cell, i_of_sub_qpoint]
: triplet_sequence(input.SuperCellMultiplier))
{
auto& _ = output.QPointData.emplace_back();
// 这一步推导过程在计算 score 的函数中
auto qpoint_by_reciprocal_modified_super_cell_in_modified_reciprocal_super_cell =
input.SuperCellDeformation.value_or(Eigen::Matrix3d::Identity()).inverse()
* qpoint_by_reciprocal_super_cell_in_modified_reciprocal_super_cell;
auto reciprocal_modified_super_cell =
(input.SuperCellMultiplier.cast<double>().asDiagonal() * input.PrimativeCell).inverse().transpose();
// sub qpoint 的坐标,单位为埃^-1

12
test/14.2.2.yaml Normal file
View File

@ -0,0 +1,12 @@
PrimativeCell:
- [ 1.548010265167714, -2.681232429908652, 0.000000000000000 ] # a
- [ 1.548010265167714, 2.681232429908652, 0.000000000000000 ] # b
- [ 0.000000000000000, 0.000000000000000, 5.061224103556595 ] # c
SuperCellMultiplier: [ 3, 1, 1 ]
PrimativeCellBasisNumber: [ 8, 8, 8 ]
AtomPositionInputFile: { FileName: "/home/chn/Documents/lammps-SiC/14/14.2/14.2.2/14.2.2.4/band.yaml", Format: "yaml" }
QPointDataInputFile: { FileName: "/home/chn/Documents/lammps-SiC/14/14.2/14.2.2/14.2.2.4/band.yaml", Format: "yaml" }
QPointDataOutputFile:
- { FileName: "test/14.2.2.result.yaml", Format: "yaml-human-readable" }
- { FileName: "test/14.2.2.result.human-readable.yaml", Format: "yaml-human-readable" }
- { FileName: "test/14.2.2.result.zpp", Format: "zpp" }