add pytorch classes uml diagram
This commit is contained in:
		
							
								
								
									
										
											BIN
										
									
								
								docs/assets/freqai_pytorch-diagram.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/assets/freqai_pytorch-diagram.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 18 KiB | 
| @@ -282,10 +282,13 @@ In addition, the trainer is responsible for the following: | ||||
| #### Integration with Freqai module  | ||||
| Like all freqai models, PyTorch models inherit `IFreqaiModel`. `IFreqaiModel` declares three abstract methods: `train`, `fit`, and `predict`. we implement these methods in three levels of hierarchy. | ||||
| From top to bottom: | ||||
| 1. `BasePyTorchModel` - all `BasePyTorch*` inherit it. Implements the `train` method responsible for general data preparation (e.g., data normalization) and calling the `fit` method. Sets `device _type` attribute used by children classes. Sets `model_type` attribute used by the parent class. | ||||
| 2. `BasePyTorch*` - Here, the `*` represents a group of algorithms, such as classifiers or regressors. the `predict` method is responsible for data preprocessing, predicting, and postprocessing if needed. | ||||
|  | ||||
| 3. PyTorch*Classifier / PyTorch*Regressor - implements the `fit` method, responsible for the main train flaw, where we initialize the trainer and model objects. | ||||
|  | ||||
|  | ||||
| 1. `BasePyTorchModel` - Implements the `train` method. all `BasePyTorch*` inherit it. responsible for general data preparation (e.g., data normalization) and calling the `fit` method. Sets `device _type` attribute used by children classes. Sets `model_type` attribute used by the parent class. | ||||
| 2. `BasePyTorch*` -  Implements the `predict` method. Here, the `*` represents a group of algorithms, such as classifiers or regressors. responsible for data preprocessing, predicting, and postprocessing if needed. | ||||
|  | ||||
| 3. `PyTorch*Classifier` / `PyTorch*Regressor` - implements the `fit` method. responsible for the main train flaw, where we initialize the trainer and model objects. | ||||
|  | ||||
| #### Full example | ||||
| Building a PyTorch regressor using MLP (multilayer perceptron) model, MSELoss criterion, and AdamW optimizer. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user