Forked from
Tessaris Sergio / wumpus
5 commits behind, 7 commits ahead of the upstream repository.
-
Tessaris Sergio authored
cli.py 3.52 KiB
#!/usr/bin/env python
"""
Command line interface
"""
import argparse
import io
import json
import os
import sys
from . import __version__
from .gridworld import GridWorld
from .runner import get_subclasses, check_entrypoint, get_player_class, get_world_class, run_episode, worlds
def _build_parser(world_classes):
    # Build the argument parser for gridrunner. ``world_classes`` is the
    # name-sorted list of GridWorld subclasses; its first entry is the default
    # --world value (so the list is assumed non-empty).
    parser = argparse.ArgumentParser(
        description=gridrunner.__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('infiles', type=argparse.FileType('r'), nargs='*',
                        help='world description JSON files, they must be compatible with the world type (see --world option).')
    parser.add_argument('--name', '-n', type=str,
                        help='name of the player, defaults to the name of the player class')
    parser.add_argument('--path', '-p', type=str, default='.',
                        help="path of the player library, it's prepended to the sys.path variable")
    parser.add_argument('--entry', '-e', type=check_entrypoint, required=True,
                        help="object reference for a Player subclass in the form 'importable.module:object.attr'. See <https://packaging.python.org/specifications/entry-points/#data-model> for details.")
    parser.add_argument('--world', '-w', type=str, default=world_classes[0].__name__,
                        choices=[c.__name__ for c in world_classes],
                        help='class name of the world')
    parser.add_argument('--horizon', '-z', type=int, default=20,
                        help='maximum number of steps')
    # Stored under ``show`` (True unless --noshow is given) so the flag reads naturally below.
    parser.add_argument('--noshow', dest='show', action='store_false',
                        help="prevent printing the world at each step")
    parser.add_argument('--out', '-o', type=argparse.FileType('w'), default=sys.stdout,
                        help="write output to file")
    parser.add_argument('--version', action='version', version='%(prog)s ' + __version__)
    parser.add_argument('--log', '-l', type=argparse.FileType('w'),
                        help="write the log of the games to file (JSON)")
    return parser


def _episode_worlds(infiles, world_class, outf):
    # Yield the worlds to play: those loaded from ``infiles`` via ``worlds``,
    # or, when no files were given, a single ``world_class.random()`` world
    # whose JSON definition is echoed to ``outf`` before it is played.
    if infiles:
        yield from worlds(infiles, world_class)
        return
    world = world_class.random()
    # show world definition
    print('-' * 10 + ' Playing on world: ' + '-' * 10, file=outf)
    world.to_JSON(outf)
    print('\n' + '-' * 40, file=outf)
    yield world


def _close_opened_files(files):
    # Close file objects that argparse opened, skipping None (option not
    # given) and the standard streams (returned by FileType for '-' or used
    # as the --out default), which must stay open.
    std_streams = (sys.stdin, sys.stdout, sys.stderr)
    for f in files:
        if f is None or any(f is s for s in std_streams):
            continue
        f.close()


def gridrunner(*args):
    """
    Run episodes on worlds using the specified player.
    """
    # NOTE: the docstring above doubles as the --help description, so the
    # details live here instead.
    #
    # ``args`` are command-line tokens (as in ``sys.argv[1:]``); an empty
    # call parses an empty argument list, not sys.argv. Returns 0 as the exit
    # status (argparse itself exits on invalid arguments, --help, --version).
    # If --log is given, the file receives a JSON array with one entry per
    # episode, as returned by ``run_episode``.
    world_classes = sorted(get_subclasses(GridWorld), key=lambda c: c.__name__)
    opts = _build_parser(world_classes).parse_args(args)
    # abspath('.') already equals os.getcwd(), so no special case is needed.
    path = os.path.abspath(opts.path)
    outf: io.TextIOBase = opts.out
    game_log: io.TextIOBase = opts.log
    try:
        player_class = get_player_class(opts.entry, path=path)
        world_class = get_world_class(opts.world)
        name = opts.name if opts.name is not None else player_class.__name__
        player = player_class(name=name)
        if game_log is not None:
            print('[', file=game_log)
        episodes = _episode_worlds(opts.infiles, world_class, outf)
        for index, world in enumerate(episodes):
            glog = run_episode(world, player, horizon=opts.horizon, show=opts.show, outf=outf)
            if game_log is not None:
                # JSON array separator goes before every entry but the first.
                if index > 0:
                    print(',', file=game_log)
                json.dump(glog, game_log)
        if game_log is not None:
            print(']', file=game_log)
    finally:
        # Flush and release files opened by argparse. Previously they were
        # left to interpreter shutdown, which loses buffered log output when
        # gridrunner is called programmatically.
        _close_opened_files([*opts.infiles, outf, game_log])
    return 0
def main():
    """Console entry point.

    Passes the process arguments (without the program name) to
    ``gridrunner`` and terminates the interpreter, using its return value as
    the exit status.
    """
    cli_tokens = sys.argv[1:]
    exit_status = gridrunner(*cli_tokens)
    sys.exit(exit_status)


if __name__ == "__main__":
    # Only run the CLI when executed directly, not when imported.
    main()