alvinalexander.com | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (DeepLearning4jDistributed.java)

This example Java source code file (DeepLearning4jDistributed.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example"™.

Learn more about this Java project at its project page.

Java - Java tags/keywords

actorref, address, class, cluster, deeplearning4jdistributed, exception, hazelcaststatetracker, illegalstateexception, net, network, override, reflection, runtimeexception, started, starting, string, threading, threads, util, workrouter

The DeepLearning4jDistributed.java Java example source code

/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.scaleout.actor.runner;

import akka.actor.*;
import akka.cluster.Cluster;
import akka.contrib.pattern.ClusterClient;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.routing.RoundRobinPool;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.scaleout.actor.core.ClusterListener;
import org.deeplearning4j.scaleout.actor.core.ModelSaver;
import org.deeplearning4j.scaleout.actor.core.actor.BatchActor;
import org.deeplearning4j.scaleout.actor.core.actor.MasterActor;
import org.deeplearning4j.scaleout.actor.core.actor.ModelSavingActor;
import org.deeplearning4j.scaleout.actor.core.actor.WorkerActor;
import org.deeplearning4j.scaleout.actor.util.ActorRefUtils;
import org.deeplearning4j.scaleout.aggregator.INDArrayAggregator;
import org.deeplearning4j.scaleout.aggregator.JobAggregator;
import org.deeplearning4j.scaleout.api.workrouter.WorkRouter;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.job.JobIterator;
import org.deeplearning4j.scaleout.messages.MoreWorkMessage;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.perform.WorkerPerformerFactory;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.scaleout.workrouter.IterativeReduceWorkRouter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.Duration;

import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * Controller for coordinating model training for a neural network based
 * on parameters across an akka cluster.
 *
 * <p>An instance runs either as the master (the default, {@code type == "master"})
 * or as a worker (any other type). {@link #setup(Configuration)} creates the actor
 * system and starts the actors for the selected role, {@link #train()} kicks off
 * the work and blocks until the state tracker reports completion, and
 * {@link #shutdown()} releases the actor system, the state tracker and the
 * internal scheduler.
 *
 * @author Adam Gibson
 *
 */
public class DeepLearning4jDistributed implements DeepLearningConfigurable,Serializable {


    private static final long serialVersionUID = -4385335922485305364L;
    private static final Logger log = LoggerFactory.getLogger(DeepLearning4jDistributed.class);
    /** Name of the akka actor system created by {@link #setup(Configuration)}. */
    private static final String SYSTEM_NAME = "ClusterSystem";
    /** Value of {@link #type} that selects the master role. */
    private static final String MASTER_TYPE = "master";

    private transient ActorSystem system;
    private ActorRef mediator;
    private String type = MASTER_TYPE;
    private Address masterAddress;
    private JobIterator iter;
    protected ActorRef masterActor;
    protected ModelSaver modelSaver;
    private transient ScheduledExecutorService exec;
    private transient StateTracker stateTracker;
    private int stateTrackerPort = -1;
    private String masterHost;
    private transient WorkRouter workRouter;



    /**
     * Creates a node of the given type that reads its work from the given iterator.
     * @param type the role of this node; {@code "master"} selects the master role,
     *             any other value selects the worker role
     * @param iter the dataset to use
     */
    public DeepLearning4jDistributed(String type, JobIterator iter) {
        this.type = type;
        this.iter = iter;
    }



    /**
     * Master constructor with an externally supplied state tracker.
     * @param iter the dataset to use
     * @param stateTracker the state tracker to use instead of creating one during setup
     */
    public DeepLearning4jDistributed(JobIterator iter,StateTracker stateTracker) {
        this(MASTER_TYPE,iter);
        this.stateTracker = stateTracker;
    }

    /**
     * Master constructor
     * @param iter the dataset to use
     */
    public DeepLearning4jDistributed(JobIterator iter) {
        this(MASTER_TYPE,iter);
    }



    /**
     * The worker constructor
     * @param type the type to use
     * @param address the address of the master, as a URI string
     */
    public DeepLearning4jDistributed(String type, String address) {
        this.type = type;
        URI u = URI.create(address);
        masterAddress = Address.apply(u.getScheme(), u.getUserInfo(), u.getHost(), u.getPort());
    }




    /**
     * Creates a master-typed instance without a dataset; {@link #setup(Configuration)}
     * rejects a master that has no dataset.
     */
    public DeepLearning4jDistributed() {
        super();
    }




    /**
     * Start the master backend: the work router, the batch actor and the
     * cluster singleton master actor. Requires the actor system created by
     * {@link #setup(Configuration)}.
     *
     * @param joinAddress the cluster address to join, or {@code null} to join
     *                    this actor system's own address
     * @param c the configuration; {@code MASTER_URL} and {@code MASTER_PATH} are written to it
     * @param iter the dataset handed to the batch actor
     * @param stateTracker the state tracker shared by the router and the actors
     * @return the address that was joined
     * @throws RuntimeException if the configured work router cannot be instantiated
     */
    public Address startBackend(Address joinAddress,Configuration c,JobIterator iter,StateTracker stateTracker) {

        // NOTE(review): setup() registers the same hook before calling this method;
        // kept because this method is public - confirm the helper tolerates a second call.
        ActorRefUtils.addShutDownForSystem(system);

        system.actorOf(Props.create(ClusterListener.class));

        try {
            String routerName = c.get(WorkRouter.WORK_ROUTER, IterativeReduceWorkRouter.class.getName());
            // asSubclass fails before instantiation if the class is not a WorkRouter
            Class<? extends WorkRouter> routerClazz = Class.forName(routerName).asSubclass(WorkRouter.class);
            Constructor<? extends WorkRouter> constructor = routerClazz.getConstructor(StateTracker.class);
            workRouter = constructor.newInstance(stateTracker);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        workRouter.setup(c);

        ActorRef batchActor = system.actorOf(Props.create(BatchActor.class,iter,stateTracker,c,workRouter),"batch");

        log.info("Started batch actor");

        Props masterProps = Props.create(MasterActor.class,c,batchActor,stateTracker,workRouter);

        Cluster cluster = Cluster.get(system);
        final Address realJoinAddress = (joinAddress == null) ? cluster.selfAddress() : joinAddress;

        c.set(MASTER_URL,realJoinAddress.toString());

        if (exec == null) {
            exec = Executors.newScheduledThreadPool(2);
        }

        cluster.join(realJoinAddress);

        // one-shot publication of the cluster state ten seconds after joining
        exec.schedule(new Runnable() {

            @Override
            public void run() {
                Cluster.get(system).publishCurrentClusterState();
            }

        }, 10, TimeUnit.SECONDS);

        /*
         * Starts a master: in the active state with the poison pill upon failure with the role of master
         */
        masterActor = system.actorOf(
                ClusterSingletonManager.defaultProps(masterProps, "master", PoisonPill.getInstance(), "master"));

        log.info("Started master with address {}", realJoinAddress);
        c.set(MASTER_PATH,ActorRefUtils.absPath(masterActor, system));
        log.info("Set master abs path {}", c.get(MASTER_PATH));

        return realJoinAddress;
    }


    /**
     * Creates the actor system and starts this node in its configured role.
     *
     * @param conf the configuration; the master writes {@code MASTER_URL},
     *             {@code MASTER_PATH} and {@code STATE_TRACKER_CONNECTION_STRING} to it,
     *             a worker reads {@code MASTER_URL} and {@code STATE_TRACKER_CONNECTION_STRING} from it
     * @throws IllegalStateException if this is a master without a dataset
     * @throws RuntimeException wrapping any failure while starting the role
     */
    @Override
    public void setup(final Configuration conf) {

        system = ActorSystem.create(SYSTEM_NAME);
        ActorRefUtils.addShutDownForSystem(system);
        mediator = DistributedPubSubExtension.get(system).mediator();

        boolean master = MASTER_TYPE.equals(type);

        if (master) {
            setupMaster(conf);
        } else {
            setupWorker(conf);
        }

        //only start dropwizard on the master
        if (master) {
            stateTracker.startRestApi();
        } else if (stateTracker instanceof HazelCastStateTracker) {
            log.info("Not starting drop wizard; worker state detected");
        }
    }

    /** Starts the state tracker, the master backend and the model saver, then publishes the configuration. */
    private void setupMaster(Configuration conf) {
        if (iter == null) {
            throw new IllegalStateException("Unable to initialize no dataset to iterate");
        }

        log.info("Starting master");

        try {
            if (stateTracker == null) {
                stateTracker = stateTrackerPort > 0
                        ? new HazelCastStateTracker(stateTrackerPort)
                        : new HazelCastStateTracker();
            }

            if (stateTracker.jobAggregator() == null) {
                String aggregatorName = conf.get(JobAggregator.AGGREGATOR, INDArrayAggregator.class.getName());
                JobAggregator agg = Class.forName(aggregatorName).asSubclass(JobAggregator.class).newInstance();
                stateTracker.setJobAggregator(agg);
            }

            log.info("Started state tracker with connection string {}", stateTracker.connectionString());

            masterAddress = startBackend(null,conf,iter,stateTracker);

        } catch (Exception e1) {
            restoreInterruptIfNeeded(e1);
            throw new RuntimeException(e1);
        }

        log.info("Starting Save saver");
        if (modelSaver == null) {
            system.actorOf(Props.create(ModelSavingActor.class,"model-saver",stateTracker));
        } else {
            system.actorOf(Props.create(ModelSavingActor.class,modelSaver,stateTracker));
        }

        //store it in zookeeper for service discovery
        conf.set(MASTER_URL,getMasterAddress().toString());
        conf.set(MASTER_PATH,ActorRefUtils.absPath(masterActor, system));

        //sets up the connection string for reference on the external worker
        conf.set(STATE_TRACKER_CONNECTION_STRING,stateTracker.connectionString());
        ActorRefUtils.registerConfWithZooKeeper(conf, system);

        scheduleClusterMemberLogging();
    }

    /** Joins the master's cluster, connects to its state tracker and starts the worker actors. */
    private void setupWorker(Configuration conf) {
        log.info("Starting worker node");
        Address a = AddressFromURIString.parse(conf.get(MASTER_URL));

        Configuration c = new Configuration(conf);
        Cluster.get(system).join(a);

        try {
            // host() is a scala Option: calling get() on an empty one throws instead of returning null
            if (a.host().isEmpty()) {
                throw new IllegalArgumentException("No host set for worker");
            }

            String connectionString = conf.get(STATE_TRACKER_CONNECTION_STRING);
            if (connectionString == null) {
                throw new IllegalStateException("No state tracker connection string configured under "
                        + STATE_TRACKER_CONNECTION_STRING);
            }

            //issue with setting the master url, fallback
            if (connectionString.contains("0.0.0.0")) {
                if (masterHost == null) {
                    throw new IllegalStateException("No master host specified and host discovery was lost due to" +
                            " improper setup on the master (related to hostname resolution) Please run the following" +
                            " command on your host: sudo hostname YOUR_HOST_NAME." +
                            " This will make your hostname resolution work correctly on master.");
                }
                connectionString = connectionString.replace("0.0.0.0",masterHost);
            }

            log.info("Creating state tracker with connection string {}", connectionString);
            if (stateTracker == null) {
                stateTracker = new HazelCastStateTracker(connectionString);
            }

        } catch (Exception e1) {
            restoreInterruptIfNeeded(e1);
            throw new RuntimeException(e1);
        }

        startWorker(c);

        scheduleClusterMemberLogging();
        log.info("Setup worker nodes");
    }

    /** Logs the current cluster members once a minute for as long as the actor system is running. */
    private void scheduleClusterMemberLogging() {
        system.scheduler().schedule(Duration.create(1, TimeUnit.MINUTES),
                Duration.create(1, TimeUnit.MINUTES),
                new Runnable() {

                    @Override
                    public void run() {
                        if (system.isTerminated()) {
                            return;
                        }
                        try {
                            log.info("Current cluster members {}", Cluster.get(system).readView().members());
                        } catch (Exception e) {
                            log.warn("Tried reading cluster members during shutdown", e);
                        }
                    }

                }, system.dispatcher());
    }

    /** Re-asserts the interrupt flag only when the failure actually was an interruption. */
    private static void restoreInterruptIfNeeded(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
    }


    /**
     * Starts the worker side: a cluster listener, a cluster client pointed at the
     * master and a round robin pool of worker actors (one per available processor),
     * then joins the master's cluster. Requires the actor system and state tracker
     * to be set up already.
     *
     * @param conf the configuration; {@code MASTER_URL} and the worker performer
     *             factory class name are read from it
     * @throws RuntimeException wrapping any failure, including a state tracker that
     *                          reports one worker or fewer
     */
    public  void startWorker(Configuration conf) {

        Address contactAddress = AddressFromURIString.parse(conf.get(MASTER_URL));

        system.actorOf(Props.create(ClusterListener.class));
        log.info("Attempting to join node {}", contactAddress);
        log.info("Starting workers");
        Set<ActorSelection> initialContacts = new HashSet<>();
        initialContacts.add(system.actorSelection(contactAddress + "/user/"));

        system.actorOf(ClusterClient.defaultProps(initialContacts), "clusterClient");

        try {
            if (contactAddress.host().isEmpty()) {
                throw new IllegalArgumentException("No host set for worker");
            }
            log.info("Connecting  to host {}", contactAddress.host().get());

            int workers = stateTracker.numWorkers();
            if (workers <= 1) {
                throw new IllegalStateException("Did not properly connect to cluster");
            }

            log.info("Joining cluster of size {}", workers);

            String factoryName = conf.get(WorkerPerformerFactory.WORKER_PERFORMER);
            if (factoryName == null) {
                throw new IllegalStateException("No worker performer factory configured under "
                        + WorkerPerformerFactory.WORKER_PERFORMER);
            }
            WorkerPerformerFactory factory =
                    Class.forName(factoryName).asSubclass(WorkerPerformerFactory.class).newInstance();
            WorkerPerformer performer = factory.create(conf);

            RoundRobinPool pool = new RoundRobinPool(Runtime.getRuntime().availableProcessors());
            Props p = pool.props(WorkerActor.propsFor(conf, stateTracker,performer));
            system.actorOf(p, "worker");

            Cluster.get(system).join(contactAddress);

            log.info("Worker joining cluster of {}", stateTracker.workers().size());

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }




    /**
     * Kicks off the distributed training.
     * It will grab the optimal batch size off of
     * the beginning of the dataset iterator which
     * is based on the desired mini batch size (conf.getSplit())
     *
     * and the number of initial workers in the state tracker after setup.
     *
     * For example, if you have a mini batch size of 10 and 8 workers
     *
     * the initial {@code JobIterator#next(int batches)} would be
     *
     * 80, this would be 10 per worker.
     *
     * <p>Blocks, polling the state tracker every 15 seconds, until it reports
     * completion and then calls {@link #shutdown()}. If the calling thread is
     * interrupted while waiting, the wait is abandoned, the interrupt flag is
     * left set and shutdown still runs.
     */
    public void train() {
        log.info("Publishing to results for training");


        log.info("Started pipeline");
        //start the pipeline
        mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER,
                MoreWorkMessage.getInstance()), mediator);


        log.info("Published results");
        while (!stateTracker.isDone()) {
            log.info("State tracker not done...blocking");
            try {
                Thread.sleep(15000);
            } catch (InterruptedException e) {
                // With the flag restored every further sleep would throw immediately,
                // so continuing the loop would busy-spin; stop waiting instead.
                Thread.currentThread().interrupt();
                log.warn("Interrupted while waiting for training to finish; shutting down");
                break;
            }
        }

        shutdown();
    }




    public Address getMasterAddress() {
        return masterAddress;
    }

    public StateTracker getStateTracker() {
        return stateTracker;
    }

    public  void setStateTracker(
            StateTracker stateTracker) {
        this.stateTracker = stateTracker;
    }

    /**
     * Shut down this network actor: the internal scheduler, then the actor
     * system, then the state tracker. Failures are logged and do not prevent
     * the remaining steps from running.
     */
    public void shutdown() {
        // the executor's worker threads would otherwise outlive this instance
        if (exec != null) {
            exec.shutdownNow();
            exec = null;
        }

        //order matters here: the actor system is stopped before the state tracker
        try {
            if (system != null) {
                system.shutdown();
            }
        } catch (Exception e) {
            log.warn("Error shutting down the actor system", e);
        }

        try {
            if (stateTracker != null) {
                stateTracker.shutdown();
            }
        } catch (Exception e) {
            log.warn("Error shutting down the state tracker", e);
        }
    }

    public ModelSaver getModelSaver() {
        return modelSaver;
    }

    /**
     * Sets a custom model saver. This will allow custom directories
     * among other things when saving snapshots.
     * @param modelSaver the model saver to use
     */
    public  void setModelSaver(ModelSaver modelSaver) {
        this.modelSaver = modelSaver;
    }

    /**
     * Gets the state tracker port.
     * A lot of state trackers will be servers
     * that need to be bound on a port.
     * This will allow overrides per implementation of the state tracker
     * @return the state tracker port that the state tracker
     * server will bind to
     */
    public  int getStateTrackerPort() {
        return stateTrackerPort;
    }

    /**
     * Sets the port for the state tracker created during master setup;
     * values of zero or below select the tracker's no-argument constructor.
     * @param stateTrackerPort the port to bind to
     */
    public  void setStateTrackerPort(int stateTrackerPort) {
        this.stateTrackerPort = stateTrackerPort;
    }

    public String getMasterHost() {
        return masterHost;
    }

    /**
     * Sets the host substituted for {@code 0.0.0.0} in the state tracker
     * connection string when a worker is set up.
     * @param masterHost the master's host name
     */
    public void setMasterHost(String masterHost) {
        this.masterHost = masterHost;
    }
}

Other Java examples (source code examples)

Here is a short list of links related to this Java DeepLearning4jDistributed.java source code file:

... this post is sponsored by my books ...

#1 New Release!

FP Best Seller

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.