eis/TestProject/RNG/model/ModelRegistry.cc

133 lines
5.6 KiB
C++
Raw Normal View History

#include <TestProject/RNG/model/ModelRegistry.h>
#include <TestProject/RNG/model/ConstantModel.h>
#include <TestProject/RNG/model/NormalModel.h>
#include <TestProject/RNG/model/LinearModel.h>
#include <TestProject/RNG/model/SineModel.h>
#include <TestProject/RNG/model/UniformModel.h>
#include <TestProject/RNG/model/SpikeModel.h>
#include <TestProject/RNG/model/DriftModel.h>
#include <TestProject/RNG/model/CsvReplayModel.h>
#include <TestProject/RNG/model/BoolRandomModel.h>
#include <TestProject/RNG/model/BoolToggleModel.h>
#include <TestProject/RNG/model/BoolCsvModel.h>
#include <TestProject/RNG/model/ValvePairModel.h>
#include <TestProject/RNG/model/NotModel.h>
#include <TestProject/RNG/model/CompositeModel.h>
#include <fstream>
#include <stdexcept>
// Process-wide registry accessor. The registry is a function-local static, so it
// is constructed on the first call (C++11 guarantees that initialization itself is
// thread-safe) and every call returns a reference to that same object.
ModelRegistry& ModelRegistry::instance() {
    static ModelRegistry singleton;
    return singleton;
}
// Registers the built-in modes. Each factory forwards the template's JSON params
// and the caller's default value unchanged to the concrete model's constructor.
// "!" negation and "a+b" composites are not modes; createModel builds those
// (NotModel / CompositeModel) directly.
ModelRegistry::ModelRegistry() {
    // Generators selected by plain mode name.
    registerMode("constant", [](const json& params, float defaultVal) {
        return std::make_unique<ConstantModel>(params, defaultVal);
    });
    registerMode("normal", [](const json& params, float defaultVal) {
        return std::make_unique<NormalModel>(params, defaultVal);
    });
    registerMode("linear", [](const json& params, float defaultVal) {
        return std::make_unique<LinearModel>(params, defaultVal);
    });
    registerMode("sine", [](const json& params, float defaultVal) {
        return std::make_unique<SineModel>(params, defaultVal);
    });
    registerMode("uniform", [](const json& params, float defaultVal) {
        return std::make_unique<UniformModel>(params, defaultVal);
    });
    registerMode("spike", [](const json& params, float defaultVal) {
        return std::make_unique<SpikeModel>(params, defaultVal);
    });
    registerMode("drift", [](const json& params, float defaultVal) {
        return std::make_unique<DriftModel>(params, defaultVal);
    });
    // Also used by createModel for inline "csv:file:col" specs.
    registerMode("csv", [](const json& params, float defaultVal) {
        return std::make_unique<CsvReplayModel>(params, defaultVal);
    });

    // bool_* modes.
    registerMode("bool_random", [](const json& params, float defaultVal) {
        return std::make_unique<BoolRandomModel>(params, defaultVal);
    });
    registerMode("bool_toggle", [](const json& params, float defaultVal) {
        return std::make_unique<BoolToggleModel>(params, defaultVal);
    });
    registerMode("bool_csv", [](const json& params, float defaultVal) {
        return std::make_unique<BoolCsvModel>(params, defaultVal);
    });

    // createModel resolves "pair:action" specs against templates of this mode.
    registerMode("valve_pair", [](const json& params, float defaultVal) {
        return std::make_unique<ValvePairModel>(params, defaultVal);
    });
}
// Binds `mode` to the factory `ctor`. Registering a mode that already exists
// replaces its previous factory.
void ModelRegistry::registerMode(const std::string& mode, Ctor ctor) {
    auto& slot = factory[mode];
    slot = std::move(ctor);
}
// Loads model templates from the JSON file at `jsonPath` into modelTemplates.
// Expected shape:
//   { "models": { "<name>": { "mode": "<registered mode>", "params": { ... } }, ... } }
// "params" is optional and defaults to an empty object. A template whose name is
// already present is overwritten, so repeated calls merge files.
// Throws std::runtime_error if the file cannot be opened, has no "models" object,
// or a model entry lacks a string "mode". JSON syntax errors propagate from the
// json parser unchanged.
void ModelRegistry::loadModels(const std::string& jsonPath) {
    std::ifstream f(jsonPath);
    if (!f.is_open()) throw std::runtime_error("Cannot open " + jsonPath);
    json j;
    f >> j;

    // Use find() rather than operator[]: on a missing key, operator[] inserts a null,
    // and iterating a null yields nothing, so the load would silently do nothing.
    const auto modelsIt = j.find("models");
    if (modelsIt == j.end() || !modelsIt->is_object()) {
        throw std::runtime_error("No \"models\" object in " + jsonPath);
    }

    for (auto& [name, def] : modelsIt->items()) {
        // Likewise avoid def["mode"]: a missing key would insert null and then fail
        // with a type error that does not say which model is broken.
        const auto modeIt = def.find("mode");
        if (modeIt == def.end() || !modeIt->is_string()) {
            throw std::runtime_error("Model \"" + name + "\" in " + jsonPath +
                                     " has no string \"mode\"");
        }
        modelTemplates[name] = { modeIt->get<std::string>(), def.value("params", json::object()) };
    }
}
std::unique_ptr<IModel> ModelRegistry::createModel(const std::string& modelName, float defaultVal) {
// Negation: !model
if (!modelName.empty() && modelName[0] == '!') {
return std::make_unique<NotModel>(createModel(modelName.substr(1), defaultVal));
}
// Composite: base+noise
auto plusPos = modelName.find('+');
if (plusPos != std::string::npos) {
auto base = createModel(modelName.substr(0, plusPos), defaultVal);
auto noise = createModel(modelName.substr(plusPos + 1), defaultVal);
return std::make_unique<CompositeModel>(std::move(base), std::move(noise));
}
// Inline CSV: csv:file:col
if (modelName.rfind("csv:", 0) == 0) {
auto first = modelName.find(':', 4);
auto second = modelName.find(':', first + 1);
json p;
p["file"] = modelName.substr(4, first - 4);
p["column"] = std::stoi(modelName.substr(first + 1, second - first - 1));
return factory["csv"](p, 0.0f);
}
// Valve pair reference: pair_model:action_model
auto colonPos = modelName.find(':');
if (colonPos != std::string::npos) {
std::string pairModel = modelName.substr(0, colonPos);
std::string actionModel = modelName.substr(colonPos + 1);
auto it = modelTemplates.find(pairModel);
if (it != modelTemplates.end() && it->second.mode == "valve_pair") {
auto model = factory["valve_pair"](it->second.params, defaultVal);
static_cast<ValvePairModel*>(model.get())->actionModelName = actionModel;
return model;
}
}
// Simple model name lookup
auto it = modelTemplates.find(modelName);
if (it == modelTemplates.end()) {
it = modelTemplates.find("normal_tiny");
}
auto fit = factory.find(it->second.mode);
if (fit == factory.end()) {
throw std::runtime_error("Unknown mode: " + it->second.mode);
}
return fit->second(it->second.params, defaultVal);
}
// Returns the shared instance for `instanceKey`, creating it on first request. When
// instanceKey is empty, the key is `spec`, or "default" if spec is also empty.
// A new instance is built with createModel(spec, defaultVal); an empty spec builds
// "normal_tiny". It is stored in `instances`, which owns it, and is also indexed in
// byModelName under the spec's base name: the text before the first '+', then
// before the first ':'.
// For an existing key, `spec` and `defaultVal` are ignored. The returned pointer is
// non-owning. Exceptions from createModel propagate, and nothing is registered.
IModel* ModelRegistry::getOrCreate(const std::string& spec, float defaultVal,
                                   const std::string& instanceKey) {
    const std::string key = instanceKey.empty() ? (spec.empty() ? "default" : spec) : instanceKey;
    if (auto existing = instances.find(key); existing != instances.end()) {
        return existing->second.get();
    }

    auto model = createModel(spec.empty() ? "normal_tiny" : spec, defaultVal);
    IModel* raw = model.get();
    // Transfer ownership BEFORE indexing the raw pointer. If the map insertion
    // throws, `model` is destroyed during unwinding, and byModelName must not
    // already hold a pointer to it.
    instances[key] = std::move(model);

    // Track by model name (extract base name: strip composite/noise and pair suffixes)
    std::string modelName = spec.substr(0, spec.find('+'));
    modelName = modelName.substr(0, modelName.find(':'));
    // NOTE(review): an empty spec is indexed under "" (not "normal_tiny"), and a
    // leading '!' is kept in the name. Confirm this is what findByModelName callers expect.
    byModelName[modelName].push_back(raw);
    return raw;
}
// Returns a copy of the (non-owning) instance pointers that getOrCreate indexed
// under `modelName`, in registration order. Returns an empty vector if there are none.
std::vector<IModel*> ModelRegistry::findByModelName(const std::string& modelName) {
    const auto found = byModelName.find(modelName);
    if (found == byModelName.end()) {
        return {};
    }
    return found->second;
}
// ValvePairModel::linkPeers is defined in this file (not with ValvePairModel) to
// avoid a circular include between the model and the registry.
// Points `action` at the first instance registered under actionModelName. If no
// such instance exists, `action` is left unchanged.
void ValvePairModel::linkPeers(ModelRegistry& reg) {
    const auto peers = reg.findByModelName(actionModelName);
    if (peers.empty()) {
        return;
    }
    action = peers.front();
}