Source

core/src/Agents/MCAgent/MCAgent.ts

import seedrandom from 'seedrandom';
import PersistableAgent from '../../RLInterface/PersistableAgent';
import { General, JSONTensor, MathUtils, Tensor } from '../../Utils';
import {
    EnvStateContext,
    Experience,
} from '../../RLInterface/SingleAgentEnvironment';
import Environment from '../../RLInterface/Environment';
import * as FileStrategies from '../../RLInterface/FileStrategy';

/**
 * Configuration accepted by the Monte Carlo agent.
 *
 * Epsilon decay is optional. It only runs when both `epsilonDecaySteps` and
 * `epsilonEnd` are supplied. The agent's decay routine is invoked once per
 * completed (terminal) episode.
 * @category Agents
 * @subcategory MCAgent
 * @property {number} discountFactor factor applied to later payoffs when accumulating the return of an episode
 * @property {number} epsilonStart exploration rate installed whenever this configuration is applied
 * @property {?number} epsilonDecaySteps decay horizon forwarded to General.linearDecayEpsilon
 * @property {?number} epsilonEnd target exploration rate forwarded to General.linearDecayEpsilon
 */
export interface MCAgentSettings {
    discountFactor: number;
    epsilonStart: number;
    epsilonDecaySteps?: number;
    epsilonEnd?: number;
}

/**
 * Persisted state of a Monte Carlo agent, written by `save` and read by `load`.
 *
 * Both tensors are indexed by the encoded state indices followed by the
 * action index.
 * @category Agents
 * @subcategory MCAgent
 * @property {JSONTensor} stateReturnCountTable number of returns averaged into each state-action estimate so far
 * @property {JSONTensor} valueTable running mean return (action-value estimate) per state-action pair
 */
export interface MCSaveFormat {
    stateReturnCountTable: JSONTensor;
    valueTable: JSONTensor;
}

/**
 * Tabular first-visit Monte Carlo control agent.
 *
 * Transitions passed to `feed` are buffered until the episode terminates.
 * The buffered episode is then replayed backwards to accumulate discounted
 * returns. The action-value estimate of each (state, action) pair is updated
 * with the return that follows the pair's FIRST occurrence in the episode.
 * Actions are chosen epsilon-greedily with respect to the action-value table.
 * @category Agents
 * @extends PersistableAgent
 * @param {Environment} env The environment
 * @param {?MCAgentSettings} config The configuration
 * @param {?number} randomSeed the random seed
 */
class MCAgent extends PersistableAgent<MCSaveFormat, MCAgentSettings> {
    private _config?: MCAgentSettings;
    private rng: seedrandom.PRNG;
    private randomSeed?: string;
    private _valueTable: Tensor;
    private _stateReturnCountTable: Tensor;
    private _experience: Experience[] = [];

    private epsilon: number = 0;
    private epsilonStep: number = 0;

    constructor(
        env: Environment,
        config?: MCAgentSettings,
        randomSeed?: number
    ) {
        super(env);
        this.setRandomSeed(randomSeed);
        this._config = config;
        // Apply the starting epsilon right away. If tables are restored via
        // `load()`, `init()` skips `reset()`, and epsilon would otherwise stay 0.
        if (config != null) {
            this.epsilon = config.epsilonStart;
        }
    }

    public get config(): object | undefined {
        return this._config;
    }

    /**
     * Get the value table (running mean return per state-action pair)
     * @type {Tensor}
     */
    public get valueTable(): Tensor {
        return this._valueTable;
    }

    /**
     * Get the state return count table (number of returns averaged per state-action pair)
     * @type {Tensor}
     */
    public get stateReturnCountTable(): Tensor {
        return this._stateReturnCountTable;
    }

    /**
     * Get a shallow copy of the experiences buffered for the current episode
     * @type {Experience[]}
     */
    public get experience(): Experience[] {
        return this._experience.map((entry) => Object.assign({}, entry));
    }

    /**
     * True once both tables exist and a configuration has been set.
     * @type {boolean}
     */
    public get trainingInitialized(): boolean {
        return (
            this._valueTable !== undefined &&
            this._stateReturnCountTable !== undefined &&
            this._config !== undefined
        );
    }

    /**
     * Initialize the agent unless tables and configuration are already present.
     */
    public init(): void {
        if (!this.trainingInitialized) {
            this.reset();
        }
    }

    /**
     * Re-create zeroed tables of shape [...stateDim, actionCount], re-apply
     * the current configuration (epsilon back to epsilonStart) and drop any
     * partially buffered episode.
     */
    public reset(): void {
        const valueTableDims: number[] = [...this.env.stateDim];
        valueTableDims.push(this.env.actionSpace.length);
        this._valueTable = Tensor.Zeros(valueTableDims);
        this._stateReturnCountTable = Tensor.Zeros(valueTableDims);
        // Experiences gathered before a reset belong to the old tables/episode.
        this._experience = [];
        this.setConfig(this._config);
    }

    /**
     * Set the configuration of the agent after initializing.
     * Resets epsilon to `epsilonStart` (when a config is given) and the decay step counter.
     * @param {?MCAgentSettings} config The configuration
     * @param {?number} randomSeed The random seed
     */
    public setConfig(config?: MCAgentSettings, randomSeed?: number): void {
        if (randomSeed != null) this.setRandomSeed(randomSeed);
        if (config != null) {
            this._config = config;
            this.epsilon = this._config.epsilonStart;
        }
        this.epsilonStep = 0;
    }

    /**
     * Choose an action for training using the epsilon-greedy policy.
     * @param {object} state the current state
     * @returns {string} the chosen action
     */
    public step(state: object): string {
        return this.followEpsGreedyPolicy(state);
    }

    /**
     * Buffer one transition. When the episode terminates, run a Monte Carlo
     * update. On termination or on reaching the iteration cap, clear the buffer.
     */
    public async feed(
        prevState: object,
        takenAction: string,
        newState: object,
        payoff: number,
        envStateContext: EnvStateContext
    ): Promise<void> {
        // buffer experience
        this._experience.push({
            prevState: this.env.encodeStateToIndices(prevState),
            takenAction: this.env.actionSpace.indexOf(takenAction),
            newState: this.env.encodeStateToIndices(newState),
            payoff: payoff,
            contextInfo: envStateContext,
        });
        if (envStateContext.isTerminal) {
            // use experience for training
            this.mcTrainingStep();
        }
        if (envStateContext.maxIterationReached || envStateContext.isTerminal) {
            // An episode cut off without a terminal state has no complete
            // return, so its experience is discarded untrained.
            this._experience = [];
        }
    }

    /**
     * Decay the epsilon value via General.linearDecayEpsilon.
     * No-op unless a config with a positive `epsilonDecaySteps` and an
     * `epsilonEnd` (0 included) is set.
     * @returns {void}
     */
    public decayEpsilon(): void {
        const config = this._config;
        // `epsilonEnd` is checked against null/undefined instead of truthiness
        // so that decaying down to 0 (fully greedy) is honoured.
        if (
            config == null ||
            !config.epsilonDecaySteps ||
            config.epsilonEnd == null
        ) {
            return;
        }

        const { epsilon, stepCount } = General.linearDecayEpsilon(
            this.epsilonStep,
            config.epsilonDecaySteps,
            config.epsilonStart,
            config.epsilonEnd
        );

        this.epsilon = epsilon;
        this.epsilonStep = stepCount;
    }

    /**
     * Choose the greedy action (highest action value) for a state.
     * @param {object} state the current state
     * @returns {string} the greedy action
     */
    public evalStep(state: object): string {
        const actions: number[] = this.getStateActionValues(state);
        const actionIdx: number = MathUtils.argMax(actions);
        return this.env.actionSpace[actionIdx];
    }

    public log(): void {
        console.log('epsilon:', this.epsilon);
        console.log('epsilonStep', this.epsilonStep);
    }

    /**
     * Set the random Seed for the agent
     * @param {?number} randomSeed - the random seed; an unseeded generator is used when omitted
     */
    private setRandomSeed(randomSeed?: number): void {
        if (randomSeed != null) {
            this.randomSeed = randomSeed.toString();
            this.rng = seedrandom(this.randomSeed);
        } else {
            this.rng = seedrandom();
        }
    }

    /**
     * Run one first-visit Monte Carlo update over the buffered episode.
     * @throws {Error} if no configuration has been set
     */
    private mcTrainingStep(): void {
        const config = this._config;
        if (config == null) {
            throw new Error(
                'MCAgent: a configuration must be set before training'
            );
        }
        this.decayEpsilon();

        const episode: Experience[] = this._experience;

        // Index of the first occurrence of every (state, action) pair; only
        // that occurrence contributes its return (first-visit MC).
        const firstVisitIndex = new Map<string, number>();
        episode.forEach((exp, idx) => {
            const key = MCAgent.stateActionKey(exp);
            if (!firstVisitIndex.has(key)) {
                firstVisitIndex.set(key, idx);
            }
        });

        // Walk backwards so the discounted return can be accumulated in O(1)
        // per step: G_t = R_{t+1} + gamma * G_{t+1}.
        let g = 0;
        for (let i = episode.length - 1; i >= 0; i--) {
            const exp: Experience = episode[i];
            g = g * config.discountFactor + exp.payoff;
            if (firstVisitIndex.get(MCAgent.stateActionKey(exp)) === i) {
                this.updateStateActionEstimate(exp, g);
            }
        }
    }

    /**
     * Build a hashable key identifying the (state, action) pair of an experience.
     */
    private static stateActionKey(exp: Experience): string {
        return [...exp.prevState, exp.takenAction].join(',');
    }

    /**
     * Fold one more return `g` into the running mean of the experience's
     * state-action pair and bump its visit count.
     */
    private updateStateActionEstimate(exp: Experience, g: number): void {
        const index: number[] = [...exp.prevState, exp.takenAction];
        const count =
            (this._stateReturnCountTable.get(...index) as number) + 1;
        this._stateReturnCountTable.set(index, count);
        const oldMean = this._valueTable.get(...index) as number;
        // Incremental mean: Q_n = Q_{n-1} + (G - Q_{n-1}) / n
        this._valueTable.set(index, oldMean + (g - oldMean) / count);
    }

    /**
     * Look up the row of action values for a state.
     */
    private getStateActionValues(state: object): number[] {
        const indices: number[] = this.env.encodeStateToIndices(state);
        return this._valueTable.get(...indices) as number[];
    }

    /**
     * With probability epsilon pick a uniformly random action, otherwise the greedy one.
     */
    private followEpsGreedyPolicy(state: object): string {
        const randNum: number = this.rng();
        if (randNum < this.epsilon) {
            return this.sampleRandomAction();
        } else {
            return this.evalStep(state);
        }
    }

    /**
     * Sample an action uniformly from the action space using the agent's rng.
     */
    private sampleRandomAction(): string {
        const randIdx = Math.floor(this.rng() * this.env.actionSpace.length);
        return this.env.actionSpace[randIdx];
    }

    public async load(
        fileManager: FileStrategies.JSONLoader<MCSaveFormat>
    ): Promise<void> {
        const loadObject: MCSaveFormat = await fileManager.load();
        this._valueTable = Tensor.fromJSONObject(loadObject.valueTable);
        this._stateReturnCountTable = Tensor.fromJSONObject(
            loadObject.stateReturnCountTable
        );
    }

    public async save(
        fileManager: FileStrategies.JSONSaver<MCSaveFormat>,
        options?: object
    ): Promise<void> {
        await fileManager.save({
            valueTable: this._valueTable.toJSONTensor(),
            stateReturnCountTable: this._stateReturnCountTable.toJSONTensor(),
        });
    }

    public async loadConfig(
        fileManager: FileStrategies.JSONLoader<MCAgentSettings>,
        options?: object
    ): Promise<void> {
        const loadObject: MCAgentSettings = await fileManager.load();
        this.setConfig(loadObject);
    }

    /**
     * Persist the current configuration.
     * @throws {Error} if no configuration has been set
     */
    public async saveConfig(
        fileManager: FileStrategies.JSONSaver<MCAgentSettings>,
        options?: object
    ): Promise<void> {
        if (this._config == null) {
            throw new Error('MCAgent: no configuration to save');
        }
        await fileManager.save(this._config);
    }
}

export default MCAgent;