feat: add ModelRegistry with JSON loading and composite model
Add CompositeModel for combining base+noise models, ModelRegistry singleton with JSON-based model template loading, per-instance-key model isolation, and inline CSV/valve pair/composite syntax parsing in createModel.
This commit is contained in:
parent
4f83e41c0c
commit
b4bb27f1e5
15
TestProject/RNG/model/CompositeModel.h
Normal file
15
TestProject/RNG/model/CompositeModel.h
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <TestProject/RNG/model/IModel.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
struct CompositeModel : IModel {
|
||||||
|
std::unique_ptr<IModel> base;
|
||||||
|
std::unique_ptr<IModel> noise;
|
||||||
|
|
||||||
|
CompositeModel(std::unique_ptr<IModel> b, std::unique_ptr<IModel> n)
|
||||||
|
: base(std::move(b)), noise(std::move(n)) {}
|
||||||
|
|
||||||
|
float evaluate(size_t t) override {
|
||||||
|
return base->evaluate(t) + noise->evaluate(t);
|
||||||
|
}
|
||||||
|
};
|
||||||
126
TestProject/RNG/model/ModelRegistry.cc
Normal file
126
TestProject/RNG/model/ModelRegistry.cc
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
#include <TestProject/RNG/model/ModelRegistry.h>
|
||||||
|
#include <TestProject/RNG/model/ConstantModel.h>
|
||||||
|
#include <TestProject/RNG/model/NormalModel.h>
|
||||||
|
#include <TestProject/RNG/model/LinearModel.h>
|
||||||
|
#include <TestProject/RNG/model/SineModel.h>
|
||||||
|
#include <TestProject/RNG/model/UniformModel.h>
|
||||||
|
#include <TestProject/RNG/model/SpikeModel.h>
|
||||||
|
#include <TestProject/RNG/model/DriftModel.h>
|
||||||
|
#include <TestProject/RNG/model/CsvReplayModel.h>
|
||||||
|
#include <TestProject/RNG/model/BoolRandomModel.h>
|
||||||
|
#include <TestProject/RNG/model/BoolToggleModel.h>
|
||||||
|
#include <TestProject/RNG/model/BoolCsvModel.h>
|
||||||
|
#include <TestProject/RNG/model/ValvePairModel.h>
|
||||||
|
#include <TestProject/RNG/model/CompositeModel.h>
|
||||||
|
#include <fstream>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
ModelRegistry& ModelRegistry::instance() {
|
||||||
|
static ModelRegistry reg;
|
||||||
|
return reg;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelRegistry::ModelRegistry() {
|
||||||
|
registerMode("constant", [](const json& p, float d) { return std::make_unique<ConstantModel>(p, d); });
|
||||||
|
registerMode("normal", [](const json& p, float d) { return std::make_unique<NormalModel>(p, d); });
|
||||||
|
registerMode("linear", [](const json& p, float d) { return std::make_unique<LinearModel>(p, d); });
|
||||||
|
registerMode("sine", [](const json& p, float d) { return std::make_unique<SineModel>(p, d); });
|
||||||
|
registerMode("uniform", [](const json& p, float d) { return std::make_unique<UniformModel>(p, d); });
|
||||||
|
registerMode("spike", [](const json& p, float d) { return std::make_unique<SpikeModel>(p, d); });
|
||||||
|
registerMode("drift", [](const json& p, float d) { return std::make_unique<DriftModel>(p, d); });
|
||||||
|
registerMode("csv", [](const json& p, float d) { return std::make_unique<CsvReplayModel>(p, d); });
|
||||||
|
registerMode("bool_random",[](const json& p, float d) { return std::make_unique<BoolRandomModel>(p, d); });
|
||||||
|
registerMode("bool_toggle",[](const json& p, float d) { return std::make_unique<BoolToggleModel>(p, d); });
|
||||||
|
registerMode("bool_csv", [](const json& p, float d) { return std::make_unique<BoolCsvModel>(p, d); });
|
||||||
|
registerMode("valve_pair", [](const json& p, float d) { return std::make_unique<ValvePairModel>(p, d); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void ModelRegistry::registerMode(const std::string& mode, Ctor ctor) {
|
||||||
|
factory[mode] = std::move(ctor);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ModelRegistry::loadModels(const std::string& jsonPath) {
|
||||||
|
std::ifstream f(jsonPath);
|
||||||
|
if (!f.is_open()) throw std::runtime_error("Cannot open " + jsonPath);
|
||||||
|
json j;
|
||||||
|
f >> j;
|
||||||
|
for (auto& [name, def] : j["models"].items()) {
|
||||||
|
modelTemplates[name] = { def["mode"].get<std::string>(), def.value("params", json::object()) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<IModel> ModelRegistry::createModel(const std::string& modelName, float defaultVal) {
|
||||||
|
// Composite: base+noise
|
||||||
|
auto plusPos = modelName.find('+');
|
||||||
|
if (plusPos != std::string::npos) {
|
||||||
|
auto base = createModel(modelName.substr(0, plusPos), defaultVal);
|
||||||
|
auto noise = createModel(modelName.substr(plusPos + 1), defaultVal);
|
||||||
|
return std::make_unique<CompositeModel>(std::move(base), std::move(noise));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inline CSV: csv:file:col
|
||||||
|
if (modelName.rfind("csv:", 0) == 0) {
|
||||||
|
auto first = modelName.find(':', 4);
|
||||||
|
auto second = modelName.find(':', first + 1);
|
||||||
|
json p;
|
||||||
|
p["file"] = modelName.substr(4, first - 4);
|
||||||
|
p["column"] = std::stoi(modelName.substr(first + 1, second - first - 1));
|
||||||
|
return factory["csv"](p, 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valve pair reference: pair_model:action_model
|
||||||
|
auto colonPos = modelName.find(':');
|
||||||
|
if (colonPos != std::string::npos) {
|
||||||
|
std::string pairModel = modelName.substr(0, colonPos);
|
||||||
|
std::string actionModel = modelName.substr(colonPos + 1);
|
||||||
|
auto it = modelTemplates.find(pairModel);
|
||||||
|
if (it != modelTemplates.end() && it->second.mode == "valve_pair") {
|
||||||
|
auto model = factory["valve_pair"](it->second.params, defaultVal);
|
||||||
|
static_cast<ValvePairModel*>(model.get())->actionModelName = actionModel;
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple model name lookup
|
||||||
|
auto it = modelTemplates.find(modelName);
|
||||||
|
if (it == modelTemplates.end()) {
|
||||||
|
it = modelTemplates.find("normal_tiny");
|
||||||
|
}
|
||||||
|
auto fit = factory.find(it->second.mode);
|
||||||
|
if (fit == factory.end()) {
|
||||||
|
throw std::runtime_error("Unknown mode: " + it->second.mode);
|
||||||
|
}
|
||||||
|
return fit->second(it->second.params, defaultVal);
|
||||||
|
}
|
||||||
|
|
||||||
|
IModel* ModelRegistry::getOrCreate(const std::string& spec, float defaultVal,
|
||||||
|
const std::string& instanceKey) {
|
||||||
|
std::string key = instanceKey.empty() ? (spec.empty() ? "default" : spec) : instanceKey;
|
||||||
|
auto it = instances.find(key);
|
||||||
|
if (it != instances.end()) return it->second.get();
|
||||||
|
|
||||||
|
auto model = createModel(spec.empty() ? "normal_tiny" : spec, defaultVal);
|
||||||
|
|
||||||
|
// Track by model name (extract base name: strip composite/noise and pair suffixes)
|
||||||
|
std::string modelName = spec;
|
||||||
|
auto plusPos = modelName.find('+');
|
||||||
|
if (plusPos != std::string::npos) modelName = modelName.substr(0, plusPos);
|
||||||
|
auto colonPos = modelName.find(':');
|
||||||
|
if (colonPos != std::string::npos) modelName = modelName.substr(0, colonPos);
|
||||||
|
|
||||||
|
byModelName[modelName].push_back(model.get());
|
||||||
|
instances[key] = std::move(model);
|
||||||
|
return instances[key].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IModel*> ModelRegistry::findByModelName(const std::string& modelName) {
|
||||||
|
auto it = byModelName.find(modelName);
|
||||||
|
if (it != byModelName.end()) return it->second;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValvePairModel::linkPeers defined here to avoid circular include
|
||||||
|
void ValvePairModel::linkPeers(ModelRegistry& reg) {
|
||||||
|
auto models = reg.findByModelName(actionModelName);
|
||||||
|
if (!models.empty()) action = models[0];
|
||||||
|
}
|
||||||
38
TestProject/RNG/model/ModelRegistry.h
Normal file
38
TestProject/RNG/model/ModelRegistry.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <TestProject/RNG/model/IModel.h>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
class ModelRegistry {
|
||||||
|
public:
|
||||||
|
using Ctor = std::function<std::unique_ptr<IModel>(const json& params, float defaultVal)>;
|
||||||
|
|
||||||
|
static ModelRegistry& instance();
|
||||||
|
|
||||||
|
void loadModels(const std::string& jsonPath);
|
||||||
|
|
||||||
|
IModel* getOrCreate(const std::string& tables1Spec, float defaultVal,
|
||||||
|
const std::string& instanceKey = "");
|
||||||
|
|
||||||
|
std::vector<IModel*> findByModelName(const std::string& modelName);
|
||||||
|
|
||||||
|
void registerMode(const std::string& mode, Ctor ctor);
|
||||||
|
|
||||||
|
private:
|
||||||
|
ModelRegistry();
|
||||||
|
std::unique_ptr<IModel> createModel(const std::string& modelName, float defaultVal);
|
||||||
|
|
||||||
|
struct ModelDef {
|
||||||
|
std::string mode;
|
||||||
|
json params;
|
||||||
|
};
|
||||||
|
std::map<std::string, ModelDef> modelTemplates;
|
||||||
|
std::map<std::string, std::unique_ptr<IModel>> instances;
|
||||||
|
std::map<std::string, std::vector<IModel*>> byModelName;
|
||||||
|
std::map<std::string, Ctor> factory;
|
||||||
|
};
|
||||||
Loading…
Reference in New Issue
Block a user