From b4bb27f1e5690d8ce9913133e73b44c55d7d916b Mon Sep 17 00:00:00 2001 From: Huamonarch Date: Wed, 13 May 2026 15:20:51 +0800 Subject: [PATCH] feat: add ModelRegistry with JSON loading and composite model Add CompositeModel for combining base+noise models, ModelRegistry singleton with JSON-based model template loading, per-instance-key model isolation, and inline CSV/valve pair/composite syntax parsing in createModel. --- TestProject/RNG/model/CompositeModel.h | 15 +++ TestProject/RNG/model/ModelRegistry.cc | 126 +++++++++++++++++++++++++ TestProject/RNG/model/ModelRegistry.h | 38 ++++++++ 3 files changed, 179 insertions(+) create mode 100644 TestProject/RNG/model/CompositeModel.h create mode 100644 TestProject/RNG/model/ModelRegistry.cc create mode 100644 TestProject/RNG/model/ModelRegistry.h diff --git a/TestProject/RNG/model/CompositeModel.h b/TestProject/RNG/model/CompositeModel.h new file mode 100644 index 0000000..721582d --- /dev/null +++ b/TestProject/RNG/model/CompositeModel.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + +struct CompositeModel : IModel { + std::unique_ptr base; + std::unique_ptr noise; + + CompositeModel(std::unique_ptr b, std::unique_ptr n) + : base(std::move(b)), noise(std::move(n)) {} + + float evaluate(size_t t) override { + return base->evaluate(t) + noise->evaluate(t); + } +}; diff --git a/TestProject/RNG/model/ModelRegistry.cc b/TestProject/RNG/model/ModelRegistry.cc new file mode 100644 index 0000000..f2bacb7 --- /dev/null +++ b/TestProject/RNG/model/ModelRegistry.cc @@ -0,0 +1,126 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +ModelRegistry& ModelRegistry::instance() { + static ModelRegistry reg; + return reg; +} + +ModelRegistry::ModelRegistry() { + registerMode("constant", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("normal", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("linear", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("sine", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("uniform", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("spike", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("drift", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("csv", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("bool_random",[](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("bool_toggle",[](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("bool_csv", [](const json& p, float d) { return std::make_unique(p, d); }); + registerMode("valve_pair", [](const json& p, float d) { return std::make_unique(p, d); }); +} + +void ModelRegistry::registerMode(const std::string& mode, Ctor ctor) { + factory[mode] = std::move(ctor); +} + +void ModelRegistry::loadModels(const std::string& jsonPath) { + std::ifstream f(jsonPath); + if (!f.is_open()) throw std::runtime_error("Cannot open " + jsonPath); + json j; + f >> j; + for (auto& [name, def] : j["models"].items()) { + modelTemplates[name] = { def["mode"].get(), def.value("params", json::object()) }; + } +} + +std::unique_ptr ModelRegistry::createModel(const std::string& modelName, float defaultVal) { + // Composite: base+noise + auto plusPos = modelName.find('+'); + if (plusPos != std::string::npos) { + auto base = createModel(modelName.substr(0, plusPos), defaultVal); + auto noise = createModel(modelName.substr(plusPos + 1), defaultVal); + return std::make_unique(std::move(base), std::move(noise)); + } + + // Inline CSV: csv:file:col + if (modelName.rfind("csv:", 0) == 0) { + auto first = modelName.find(':', 4); + auto second = modelName.find(':', first + 1); + json p; + p["file"] = modelName.substr(4, first - 4); + p["column"] = std::stoi(modelName.substr(first + 1, second - first - 1)); + return factory["csv"](p, 0.0f); + } + + // Valve pair reference: pair_model:action_model + auto colonPos = modelName.find(':'); + if (colonPos != std::string::npos) { + std::string pairModel = modelName.substr(0, colonPos); + std::string actionModel = modelName.substr(colonPos + 1); + auto it = modelTemplates.find(pairModel); + if (it != modelTemplates.end() && it->second.mode == "valve_pair") { + auto model = factory["valve_pair"](it->second.params, defaultVal); + static_cast(model.get())->actionModelName = actionModel; + return model; + } + } + + // Simple model name lookup + auto it = modelTemplates.find(modelName); + if (it == modelTemplates.end()) { + it = modelTemplates.find("normal_tiny"); + } + auto fit = factory.find(it->second.mode); + if (fit == factory.end()) { + throw std::runtime_error("Unknown mode: " + it->second.mode); + } + return fit->second(it->second.params, defaultVal); +} + +IModel* ModelRegistry::getOrCreate(const std::string& spec, float defaultVal, + const std::string& instanceKey) { + std::string key = instanceKey.empty() ? (spec.empty() ? "default" : spec) : instanceKey; + auto it = instances.find(key); + if (it != instances.end()) return it->second.get(); + + auto model = createModel(spec.empty() ? "normal_tiny" : spec, defaultVal); + + // Track by model name (extract base name: strip composite/noise and pair suffixes) + std::string modelName = spec; + auto plusPos = modelName.find('+'); + if (plusPos != std::string::npos) modelName = modelName.substr(0, plusPos); + auto colonPos = modelName.find(':'); + if (colonPos != std::string::npos) modelName = modelName.substr(0, colonPos); + + byModelName[modelName].push_back(model.get()); + instances[key] = std::move(model); + return instances[key].get(); +} + +std::vector ModelRegistry::findByModelName(const std::string& modelName) { + auto it = byModelName.find(modelName); + if (it != byModelName.end()) return it->second; + return {}; +} + +// ValvePairModel::linkPeers defined here to avoid circular include +void ValvePairModel::linkPeers(ModelRegistry& reg) { + auto models = reg.findByModelName(actionModelName); + if (!models.empty()) action = models[0]; +} diff --git a/TestProject/RNG/model/ModelRegistry.h b/TestProject/RNG/model/ModelRegistry.h new file mode 100644 index 0000000..087b339 --- /dev/null +++ b/TestProject/RNG/model/ModelRegistry.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +using json = nlohmann::json; + +class ModelRegistry { +public: + using Ctor = std::function(const json& params, float defaultVal)>; + + static ModelRegistry& instance(); + + void loadModels(const std::string& jsonPath); + + IModel* getOrCreate(const std::string& tables1Spec, float defaultVal, + const std::string& instanceKey = ""); + + std::vector findByModelName(const std::string& modelName); + + void registerMode(const std::string& mode, Ctor ctor); + +private: + ModelRegistry(); + std::unique_ptr createModel(const std::string& modelName, float defaultVal); + + struct ModelDef { + std::string mode; + json params; + }; + std::map modelTemplates; + std::map> instances; + std::map> byModelName; + std::map factory; +};