stable/freqtrade/freqai/prediction_models/RL/RLPrediction_agent_v2.py

225 lines
7.7 KiB
Python
Raw Normal View History

2022-08-13 15:48:58 +00:00
import torch as th
from torch import nn
from typing import Dict, List, Tuple, Type, Optional, Any, Union
import gym
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
CombinedExtractor
)
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3 import DQN
from stable_baselines3.common.policies import BasePolicy
#from stable_baselines3.common.policies import register_policy
from stable_baselines3.dqn.policies import (
QNetwork, DQNPolicy, MultiInputPolicy,
CnnPolicy, DQNPolicy, MlpPolicy)
import torch
def create_mlp_(
input_dim: int,
output_dim: int,
net_arch: List[int],
activation_fn: Type[nn.Module] = nn.ReLU,
squash_output: bool = False,
) -> List[nn.Module]:
dropout = 0.2
if len(net_arch) > 0:
number_of_neural = net_arch[0]
modules = [
nn.Linear(input_dim, number_of_neural),
nn.BatchNorm1d(number_of_neural),
nn.LeakyReLU(),
nn.Dropout(dropout),
nn.Linear(number_of_neural, number_of_neural),
nn.BatchNorm1d(number_of_neural),
nn.LeakyReLU(),
nn.Dropout(dropout),
nn.Linear(number_of_neural, number_of_neural),
nn.BatchNorm1d(number_of_neural),
nn.LeakyReLU(),
nn.Dropout(dropout),
nn.Linear(number_of_neural, number_of_neural),
nn.BatchNorm1d(number_of_neural),
nn.LeakyReLU(),
nn.Dropout(dropout),
nn.Linear(number_of_neural, output_dim)
]
return modules
class TDQNetwork(QNetwork):
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
features_extractor: nn.Module,
features_dim: int,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True
):
super().__init__(
observation_space=observation_space,
action_space=action_space,
features_extractor=features_extractor,
features_dim=features_dim,
net_arch=net_arch,
activation_fn=activation_fn,
normalize_images=normalize_images
)
action_dim = self.action_space.n
q_net = create_mlp_(self.features_dim, action_dim, self.net_arch, self.activation_fn)
self.q_net = nn.Sequential(*q_net).apply(self.init_weights)
def init_weights(self, m):
if type(m) == nn.Linear:
torch.nn.init.kaiming_uniform_(m.weight)
class TDQNPolicy(DQNPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space=observation_space,
action_space=action_space,
lr_schedule=lr_schedule,
net_arch=net_arch,
activation_fn=activation_fn,
features_extractor_class=features_extractor_class,
features_extractor_kwargs=features_extractor_kwargs,
normalize_images=normalize_images,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs
)
@staticmethod
def init_weights(module: nn.Module, gain: float = 1) -> None:
"""
Orthogonal initialization (used in PPO and A2C)
"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
nn.init.kaiming_uniform_(module.weight)
if module.bias is not None:
module.bias.data.fill_(0.0)
def make_q_net(self) -> TDQNetwork:
# Make sure we always have separate networks for features extractors etc
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return TDQNetwork(**net_args).to(self.device)
class TMultiInputPolicy(TDQNPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
class TDQN(DQN):
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"TMultiInputPolicy": TMultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[TDQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
max_grad_norm: float = 10,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 1,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
replay_buffer_class=replay_buffer_class, # No action noise
replay_buffer_kwargs=replay_buffer_kwargs,
optimize_memory_usage=optimize_memory_usage,
target_update_interval=target_update_interval,
exploration_fraction=exploration_fraction,
exploration_initial_eps=exploration_initial_eps,
exploration_final_eps=exploration_final_eps,
max_grad_norm=max_grad_norm,
tensorboard_log=tensorboard_log,
create_eval_env=create_eval_env,
policy_kwargs=policy_kwargs,
verbose=verbose,
seed=seed,
device=device,
_init_setup_model=_init_setup_model
)
# try:
# register_policy("TMultiInputPolicy", TMultiInputPolicy)
# except:
# print("already registered")