Source code for modules.testAgents.testAgents

from logs import logDecorator as lD 
import json, pprint

import numpy               as np
import torch
import torch.nn.functional as F

from lib.agents import Agent_DQN as dqn
from lib.agents import qNetwork  as qN

from lib.envs    import envUnity
from lib.utils   import ReplayBuffer as RB


# Project-wide configuration. The path is resolved relative to the current
# working directory; the `with` block closes the file as soon as the JSON has
# been parsed instead of leaking the handle.
with open('../config/config.json') as _configFile:
    config = json.load(_configFile)

# Logger name prefix for every function in this module.
logBase = config['logging']['logBase'] + '.modules.testAgents.testAgents'


@lD.log(logBase + '.testAllAgents')
def testAllAgents(logger):
    '''Smoke-test the DQN agent against the Unity environment.

    Builds an ``Agent_DQN`` from a ``SimpleReplayBuffer(1000)`` and two
    ``qNetworkDiscrete(37, 4, [50, 30, 10])`` networks with tanh
    activations. Sizes 37 and 4 presumably mean the state dimension and the
    number of actions; confirm against ``qNetworkDiscrete``.

    Inside an ``envUnity.Env`` session it then exercises the agent and
    prints progress:

    - calls ``memoryUpdateEpisode`` 3 times (``maxSteps=1000``) with an
      epsilon-greedy policy (``eps = 0.999``) to fill the replay buffer,
    - samples 20 memories and runs ``randomAction``, ``maxAction`` and
      ``epsGreedyAction(., 0.5)`` on the sampled states,
    - runs ``env.episode`` with each action-selection method,
    - saves and reloads the model via ``agent.save`` / ``agent.load``
      (``'../models'``, ``'someName'``),
    - runs ``agent.step(nSamples=100)`` followed by ``agent.softUpdate(0.2)``.

    Any exception raised along the way is caught and logged with
    ``logger.error``; it is not re-raised.

    Parameters
    ----------
    logger : {logging.Logger}
        The logger used for logging error information
    '''

    try:
        # Close the config file promptly rather than leaking the handle.
        with open('../config/modules/testAgents.json') as fIn:
            cfg = json.load(fIn)['params']

        memoryBuffer = RB.SimpleReplayBuffer(1000)
        QNslow = qN.qNetworkDiscrete(
            37, 4, [50, 30, 10], activations=[F.tanh, F.tanh, F.tanh])
        QNfast = qN.qNetworkDiscrete(
            37, 4, [50, 30, 10], activations=[F.tanh, F.tanh, F.tanh])

        with envUnity.Env(cfg['agentParams']['binaryFile'], showEnv=False) as env:

            agent = dqn.Agent_DQN(env, memoryBuffer, QNslow, QNfast, 4, 1)
            agent.eval()

            # eps close to 1 -- presumably near-random exploration while the
            # replay buffer is being filled (depends on epsGreedyAction).
            eps = 0.999

            def policy(m):
                return [agent.epsGreedyAction(m, eps)]

            print('Starting to generate memories ...')
            print('----------------------------------------')
            for _ in range(3):
                print('[Generating Memories] ', end='', flush=True)
                score = agent.memoryUpdateEpisode(policy, maxSteps=1000)
                print('Memory Buffer lengths: {}\nScore: {}'.format(
                    agent.memory.shape, score))

            print('Sampling from the memory:')
            memories = agent.memory.sample(20)
            # Each memory unpacks into 5 fields; only the first (the state)
            # is used below.
            s, _, _, _, _ = zip(*memories)
            s = np.array(s)
            print('Sampled some states of size {}'.format(s.shape))

            print('Finding the maxAction ....')
            s = torch.as_tensor(s.astype(np.float32))
            # Only checks that each action-selection method runs on a batch;
            # the returned actions are not inspected.
            agent.randomAction(s)
            agent.maxAction(s)
            agent.epsGreedyAction(s, 0.5)

            print('Random Actioon stuff ......')
            env.episode(lambda m: [agent.randomAction(m)], 10)

            print('Max Actioon stuff ......')
            env.episode(lambda m: [agent.maxAction(m)], 10)

            print('epsGreedy Actioon stuff ......')
            env.episode(lambda m: [agent.epsGreedyAction(m, 1)], 10)

            print('Load/Save a model')
            agent.save('../models', 'someName')
            agent.load('../models', 'someName')

            print('Doing a soft update')
            agent.step(nSamples=100)
            agent.softUpdate(0.2)
            print('Finished a soft update')

    except Exception as e:
        # Top-level boundary of a smoke test: log and carry on.
        logger.error(f'Unable to test all agents: {e}')

    return
@lD.log(logBase + '.main')
def main(logger, resultsDict):
    '''Entry point for the testAgents module.

    Prints an opening banner naming the module, then a closing message and
    separator. The agent smoke test (``testAllAgents``) is not invoked.

    Parameters
    ----------
    logger : {logging.Logger}
        The logger used for logging error information
    resultsDict: {dict}
        A dictionary containing information about the command line
        arguments, which can be used to override them as needed. It is not
        read by this function.
    '''
    doubleRule = '=' * 30
    singleRule = '-' * 30

    for bannerLine in (doubleRule, 'Main function of testAgents', doubleRule):
        print(bannerLine)

    # NOTE(review): call testAllAgents() here to run the agent smoke test.

    for closingLine in ('Getting out of testAgents', singleRule):
        print(closingLine)

    return