add pytorch classes uml diagram

This commit is contained in:
Yinon Polak 2023-03-23 12:13:27 +02:00
parent 36a005754a
commit fc8625c5c5
2 changed files with 6 additions and 3 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -282,10 +282,13 @@ In addition, the trainer is responsible for the following:
#### Integration with Freqai module
Like all freqai models, PyTorch models inherit `IFreqaiModel`. `IFreqaiModel` declares three abstract methods: `train`, `fit`, and `predict`. we implement these methods in three levels of hierarchy.
From top to bottom:
1. `BasePyTorchModel` - all `BasePyTorch*` inherit it. Implements the `train` method responsible for general data preparation (e.g., data normalization) and calling the `fit` method. Sets `device _type` attribute used by children classes. Sets `model_type` attribute used by the parent class.
2. `BasePyTorch*` - Here, the `*` represents a group of algorithms, such as classifiers or regressors. the `predict` method is responsible for data preprocessing, predicting, and postprocessing if needed.
3. PyTorch*Classifier / PyTorch*Regressor - implements the `fit` method, responsible for the main train flaw, where we initialize the trainer and model objects.
![image](assets/freqai_pytorch-diagram.png)
1. `BasePyTorchModel` - Implements the `train` method. all `BasePyTorch*` inherit it. responsible for general data preparation (e.g., data normalization) and calling the `fit` method. Sets `device _type` attribute used by children classes. Sets `model_type` attribute used by the parent class.
2. `BasePyTorch*` - Implements the `predict` method. Here, the `*` represents a group of algorithms, such as classifiers or regressors. responsible for data preprocessing, predicting, and postprocessing if needed.
3. `PyTorch*Classifier` / `PyTorch*Regressor` - implements the `fit` method. responsible for the main train flaw, where we initialize the trainer and model objects.
#### Full example
Building a PyTorch regressor using MLP (multilayer perceptron) model, MSELoss criterion, and AdamW optimizer.