Source

core/src/Agents/DQNAgent/DQNAgent.ts

import seedrandom from 'seedrandom';
import { Environment, EnvStateContext, Experience } from '../../index';
import * as tf from '@tensorflow/tfjs';
import { MathUtils, General } from '../../Utils';
import PersistableAgent from '../../RLInterface/PersistableAgent';
import * as FileStrategies from '../../RLInterface/FileStrategy';

/**
 * The settings of the DQN-Agent
 * @category Agents
 * @subcategory DQN
 * @property {number} learningRate The learning rate of the Adam optimizer used to train the networks
 * @property {number} discountFactor The discount factor applied to the bootstrapped value of the next state
 * @property {number[]} nnLayer The number of units of each hidden layer of the neural network, in order
 * @property {number} replayMemorySize The maximal number of experiences kept in the replay memory
 * @property {number} batchSize The number of experiences sampled from the replay memory per training step
 * @property {number} replayMemoryInitSize The number of experiences the replay memory must hold before training starts
 * @property {number} epsilonStart The initial epsilon of the epsilon-greedy policy
 * @property {number} epsilonEnd The final epsilon value after decay
 * @property {number} epsilonDecaySteps The number of decay steps over which epsilon moves from
 * `epsilonStart` to `epsilonEnd`. One decay step is taken whenever an episode ends
 * (terminal state or maximum iteration reached).
 * @property {?boolean} activateDoubleDQN Use a separate target network for computing the
 * bootstrapped learning targets (recommended)
 * @property {?number} updateTargetEvery After how many training steps the target network is
 * synchronized with the local network when double DQN is active
 * @property {?string} hiddenLayerActivation The activation function of the hidden layers
 * (defaults to 'relu' when not set).
 * See the {@link https://js.tensorflow.org/api/latest/#layers.dense|tensorflow.js} documentation
 * for available activation functions.
 * @property {?boolean} layerNorm Whether to add a layer normalization after every hidden layer.
 * This slows down training but may improve training results.
 * @property {?number} kernelInitializerSeed The seed of the He-normal kernel initializer used for
 * the hidden layers. This improves the reproducibility of training results.
 */
export interface DQNAgentSettings {
    learningRate: number;
    discountFactor: number;
    nnLayer: number[];
    replayMemorySize: number;
    batchSize: number;
    replayMemoryInitSize: number;
    epsilonStart: number;
    epsilonEnd: number;
    epsilonDecaySteps: number;
    activateDoubleDQN?: boolean;
    updateTargetEvery?: number;
    hiddenLayerActivation?: string;
    layerNorm?: boolean;
    kernelInitializerSeed?: number;
}

/**
 * The Q-networks held by a DQN agent.
 * @category Agents
 * @subcategory DQN
 * @property {tf.Sequential} local The network that is trained and queried for actions
 * (see: {@link https://js.tensorflow.org/api/latest/#class:Sequential|tf.Sequential})
 * @property {?tf.Sequential} target The target network, only present when double DQN is active
 * (see: {@link https://js.tensorflow.org/api/latest/#class:Sequential|tf.Sequential})
 */
export interface DQNNetwork {
    /** Network that is fitted during training and used for action selection. */
    local: tf.Sequential;
    /** Target network; left undefined unless double DQN is active. */
    target?: tf.Sequential;
}

/**
 * Implementation of {@link https://arxiv.org/abs/1312.5602|DQN}
 * @category Agents
 * @extends PersistableAgent
 * @param {Environment} env The environment
 * @param {?DQNAgentSettings} config The configuration
 * @param {?number} randomSeed The seed of the agent's random number generator
 */
class DQNAgent extends PersistableAgent<tf.Sequential, DQNAgentSettings> {
    private _config?: DQNAgentSettings;
    private rng: seedrandom.PRNG;
    private experienceReplay: ReplayMemory;
    private randomSeed?: string;
    private qNetworkLocal: tf.Sequential;
    private qNetworkTarget: tf.Sequential;
    private epsilon: number;
    private epsilonStep: number = 0;
    // number of training steps since the last target-network synchronization
    private timeStep: number = 0;
    // training history returned by the most recent `fit` call
    private loss?: tf.History;

    constructor(
        env: Environment,
        config?: DQNAgentSettings,
        randomSeed?: number
    ) {
        super(env);
        this.setRandomSeed(randomSeed);
        this._config = config;
    }

    public get config(): object | undefined {
        return this._config;
    }

    private setRandomSeed(randomSeed?: number) {
        if (randomSeed !== undefined) {
            this.randomSeed = randomSeed.toString();
            this.rng = seedrandom(this.randomSeed);
        } else {
            this.rng = seedrandom();
        }
    }

    /**
     * Get the network
     * @type {DQNNetwork}
     */
    public get network(): DQNNetwork {
        return {
            local: this.qNetworkLocal,
            target: this.qNetworkTarget,
        };
    }

    /**
     * Get the ReplayMemory
     * @type {ReplayMemory}
     */
    public get replayMemory(): ReplayMemory {
        return this.experienceReplay;
    }

    /**
     * Set the configuration of the agent after initializing.
     * Resets the epsilon decay progress.
     * @param {?DQNAgentSettings} config The configuration.
     * @param {?number} randomSeed The random seed.
     */
    public setConfig(config?: DQNAgentSettings, randomSeed?: number): void {
        if (randomSeed !== undefined) this.setRandomSeed(randomSeed);
        if (config !== undefined) {
            this._config = config;
            this.epsilon = this._config.epsilonStart;
        }
        this.epsilonStep = 0;
    }

    public get trainingInitialized(): boolean {
        return (
            this._config !== undefined &&
            this.qNetworkLocal !== undefined &&
            // a target network is only required in double DQN mode
            (!this._config.activateDoubleDQN ||
                this.qNetworkTarget !== undefined)
        );
    }

    public init(): void {
        if (!this.trainingInitialized) {
            this.reset();
        }
    }

    public reset(): void {
        if (this._config) {
            this.experienceReplay = new ReplayMemory(
                this._config.replayMemorySize
            );
        }
        // create local qNetwork
        this.qNetworkLocal = this.createNetwork();
        if (this._config) {
            this.epsilon = this._config.epsilonStart;

            if (this._config.activateDoubleDQN)
                this.qNetworkTarget = this.createNetwork();
        }
    }

    public step(state: object): string {
        return this.followEpsGreedyPolicy(state);
    }

    public async feed(
        prevState: object,
        takenAction: string,
        newState: object,
        payoff: number,
        contextInfo: EnvStateContext
    ): Promise<void> {
        this.experienceReplay.save(
            this.toExperience(
                prevState,
                takenAction,
                newState,
                payoff,
                contextInfo
            )
        );
        // wait until the replay memory is large enough
        if (this.replayMemoryLargeEnougth()) {
            await this.train();
        }

        // epsilon decays once per finished episode
        if (contextInfo.isTerminal || contextInfo.maxIterationReached) {
            this.decayEpsilon();
        }
    }

    private replayMemoryLargeEnougth() {
        return this.experienceReplay.size >= this._config!.replayMemoryInitSize;
    }

    /**
     * Greedy action: the action with the highest Q-value predicted by the
     * local network for `state`.
     */
    public evalStep(state: object): string {
        const qValues = this.predictQValues(this.qNetworkLocal, [
            this.env.encodeStateToIndices(state),
        ]);
        const actionIdx = MathUtils.argMax(qValues[0]);
        return this.env.actionSpace[actionIdx];
    }

    public log(): void {
        console.log('epsilon', this.epsilon);
    }

    /**
     * Create a network
     * @returns {tf.Sequential}
     */
    public createNetwork(): tf.Sequential {
        const config = this._config!;
        const model = tf.sequential();

        // the activation identifier is a string union in tfjs; the config stores a plain string
        const activation = (config.hiddenLayerActivation || 'relu') as any;

        // A seed of 0 is a valid seed, so test for presence instead of truthiness.
        const kernelInitializer =
            config.kernelInitializerSeed != null
                ? tf.initializers.heNormal({
                      seed: config.kernelInitializerSeed,
                  })
                : tf.initializers.heNormal({});

        // Adds one dense hidden layer, optionally followed by layer normalization.
        const addHiddenLayer = (layerArgs: {
            units: number;
            inputShape?: number[];
        }): void => {
            model.add(
                tf.layers.dense({
                    ...layerArgs,
                    activation,
                    kernelInitializer,
                })
            );
            if (config.layerNorm) {
                model.add(
                    tf.layers.layerNormalization({
                        center: true,
                        scale: true,
                    })
                );
            }
        };

        // only the first hidden layer declares the input shape
        addHiddenLayer({
            units: config.nnLayer[0],
            inputShape: [this.env.stateDim.length],
        });
        for (let i = 1; i < config.nnLayer.length; i++) {
            addHiddenLayer({ units: config.nnLayer[i] });
        }

        // output layer: one linear Q-value per action
        model.add(
            tf.layers.dense({
                units: this.env.actionSpace.length,
                activation: 'linear',
            })
        );

        this.compileModel(model);
        model.summary();

        return model;
    }

    /**
     * Decay the epsilon value
     * @returns {void}
     */
    public decayEpsilon(): void {
        const config = this._config!;
        // An epsilonEnd of 0 is a legitimate target, so only a missing value disables decay.
        if (!config.epsilonDecaySteps || config.epsilonEnd == null) {
            return;
        }
        const { epsilon, stepCount } = General.linearDecayEpsilon(
            this.epsilonStep,
            config.epsilonDecaySteps,
            config.epsilonStart,
            config.epsilonEnd
        );

        this.epsilon = epsilon;
        this.epsilonStep = stepCount;
    }

    public async save(
        fileManager: FileStrategies.TFModelSaver<tf.Sequential>
    ): Promise<void> {
        await fileManager.save(this.qNetworkLocal);
    }

    public async load(
        fileManager: FileStrategies.TFModelLoader<tf.Sequential>
    ): Promise<void> {
        this.qNetworkLocal = await fileManager.load();
        this.compileModel(this.qNetworkLocal);
        this.qNetworkLocal.summary();

        // additionally load the target network when needed
        if (this._config?.activateDoubleDQN) {
            this.qNetworkTarget = <tf.Sequential>await fileManager.load();
            this.compileModel(this.qNetworkTarget);
            this.qNetworkTarget.summary();
        }
    }

    public async loadConfig(
        fileManager: FileStrategies.JSONLoader<DQNAgentSettings>
    ): Promise<void> {
        const loadObject: DQNAgentSettings = await fileManager.load();
        this.setConfig(loadObject);
    }

    public async saveConfig(
        fileManager: FileStrategies.JSONSaver<DQNAgentSettings>,
        options?: object
    ): Promise<void> {
        await fileManager.save(this._config!);
    }

    /**
     * Compile a model with its own Adam optimizer (configured learning rate)
     * and mean squared error loss.
     */
    private compileModel(model: tf.Sequential): void {
        model.compile({
            optimizer: tf.train.adam(this._config!.learningRate),
            loss: 'meanSquaredError',
            metrics: ['accuracy'],
        });
    }

    /**
     * Run `network` on a batch of encoded states and copy the predicted
     * Q-values back into plain arrays. The input and output tensors are
     * disposed before returning, because tfjs tensor memory is not released
     * automatically.
     */
    private predictQValues(
        network: tf.Sequential,
        states: number[][]
    ): number[][] {
        const input = tf.tensor(states, [
            states.length,
            this.env.stateDim.length,
        ]);
        try {
            const output = network.predict(input) as tf.Tensor;
            try {
                return output.arraySync() as number[][];
            } finally {
                output.dispose();
            }
        } finally {
            input.dispose();
        }
    }

    /**
     * One training step: sample a mini batch, build Q-learning targets and
     * fit the local network on them. In double DQN mode the target network is
     * synchronized every `updateTargetEvery` training steps.
     */
    private async train(): Promise<void> {
        const config = this._config!;
        this.timeStep++;

        const miniBatch: BatchSample = this.experienceReplay.sample(
            config.batchSize,
            this.rng
        );
        // Size everything by what was actually sampled instead of assuming
        // exactly `batchSize` experiences (presumably General.sampleN can
        // return fewer when the memory is small — TODO confirm).
        const sampleCount = miniBatch.stateBatch.length;
        if (sampleCount === 0) {
            return;
        }

        // The bootstrapped next-state value comes from the target network in
        // double DQN mode and from the local network otherwise.
        const bootstrapNetwork: tf.Sequential = config.activateDoubleDQN
            ? this.qNetworkTarget
            : this.qNetworkLocal;

        // Start from the local network's own predictions so that only the taken
        // action has a non-zero error; the other actions are left unchanged.
        const targets = this.predictQValues(
            this.qNetworkLocal,
            miniBatch.stateBatch
        );
        const nextQValues = this.predictQValues(
            bootstrapNetwork,
            miniBatch.newStateBatch
        );

        for (let i = 0; i < sampleCount; i++) {
            const action = miniBatch.actionBatch[i];
            const payoff = miniBatch.payoffBatch[i];
            // Only terminal states stop bootstrapping; episodes cut off by the
            // iteration limit still bootstrap from the next state.
            targets[i][action] = miniBatch.contextInfoBatch[i].isTerminal
                ? payoff
                : payoff + config.discountFactor * Math.max(...nextQValues[i]);
        }

        const stateTensor = tf.tensor(miniBatch.stateBatch, [
            sampleCount,
            this.env.stateDim.length,
        ]);
        const targetTensor = tf.tensor(targets, [
            sampleCount,
            this.env.actionSpace.length,
        ]);
        try {
            this.loss = await this.qNetworkLocal.fit(stateTensor, targetTensor, {
                batchSize: config.batchSize,
                verbose: 0,
            });
        } finally {
            stateTensor.dispose();
            targetTensor.dispose();
        }

        // Without a configured interval `timeStep >= undefined` would always be
        // false and the target network would never be synchronized, so fall
        // back to synchronizing after every training step.
        const syncInterval = config.updateTargetEvery ?? 1;
        if (config.activateDoubleDQN && this.timeStep >= syncInterval) {
            this.qNetworkTarget.setWeights(this.qNetworkLocal.getWeights());
            console.log('target weights updated');
            console.log('loss', this.loss);
            this.timeStep = 0;
        }
    }

    private toExperience(
        prevState: object,
        takenAction: string,
        newState: object,
        payoff: number,
        contextInfo: EnvStateContext
    ): Experience {
        return {
            prevState: this.env.encodeStateToIndices(prevState),
            takenAction: this.env.actionSpace.indexOf(takenAction),
            newState: this.env.encodeStateToIndices(newState),
            payoff: payoff,
            contextInfo: contextInfo,
        };
    }

    private sampleRandomAction(): string {
        const randIdx = Math.floor(this.rng() * this.env.actionSpace.length);
        return this.env.actionSpace[randIdx];
    }

    private followEpsGreedyPolicy(state: object): string {
        const randNum: number = this.rng();
        if (randNum < this.epsilon) {
            return this.sampleRandomAction();
        } else {
            return this.evalStep(state);
        }
    }
}

/**
 * A mini batch of experiences, split into one array per experience field.
 * Index `i` of every array belongs to the same experience.
 * @category Agents
 * @subcategory DQN
 * @property {number[][]} stateBatch The encoded previous states
 * @property {number[]} actionBatch The indices of the taken actions
 * @property {number[][]} newStateBatch The encoded resulting states
 * @property {number[]} payoffBatch The received payoffs
 * @property {EnvStateContext[]} contextInfoBatch The environment context infos
 */
export interface BatchSample {
    /** encoded states the actions were taken in */
    stateBatch: number[][];
    /** index of the taken action for each experience */
    actionBatch: number[];
    /** encoded states reached after the actions */
    newStateBatch: number[][];
    /** payoff received for each experience */
    payoffBatch: number[];
    /** environment context info for each experience */
    contextInfoBatch: EnvStateContext[];
}

/**
 * A bounded first-in-first-out store of experiences used for experience
 * replay: once more than `maxSize` experiences are saved, the oldest one is
 * dropped.
 * @category Agents
 * @subcategory DQN
 * @param {number} maxSize The maximal number of experiences kept
 */
export class ReplayMemory {
    private memory: Experience[] = [];
    private _maxSize: number;

    constructor(maxSize: number) {
        this._maxSize = maxSize;
    }

    /**
     * The maximal number of stored experiences
     * @type {number}
     */
    public get maxSize(): number {
        return this._maxSize;
    }

    /**
     * The number of currently stored experiences
     * @type {number}
     */
    public get size(): number {
        return this.memory.length;
    }

    /**
     * Draw experiences from the memory and return them as a batch
     * @param {number} batchSize How many experiences to draw
     * @param {?seedrandom.PRNG} rng Optional random number generator handed to the sampler
     * @return {BatchSample} The drawn experiences, split per field
     */
    public sample(batchSize: number, rng?: seedrandom.PRNG): BatchSample {
        const drawn: Experience[] = General.sampleN(this.memory, batchSize, rng);
        return ReplayMemory.toBatch(drawn);
    }

    // Transposes a list of experiences into one array per experience field.
    private static toBatch(experiences: Experience[]): BatchSample {
        return {
            stateBatch: experiences.map((exp) => exp.prevState),
            actionBatch: experiences.map((exp) => exp.takenAction),
            newStateBatch: experiences.map((exp) => exp.newState),
            payoffBatch: experiences.map((exp) => exp.payoff),
            contextInfoBatch: experiences.map((exp) => exp.contextInfo),
        };
    }

    /**
     * Store an experience, evicting the oldest one when the memory overflows
     * @param {Experience} experience The experience to store
     * @returns {void}
     */
    public save(experience: Experience): void {
        this.memory.push(experience);
        if (this.memory.length > this._maxSize) {
            this.memory.shift();
        }
    }
}

// DQNAgent is exposed as this module's default export.
export default DQNAgent;