Interactive neural network

Visualizing how brain cells communicate and learn

Rather than just watching a brain work, in the demos below you can create your own miniature brain! Simply click/press anywhere to add a new neuron to the brain. Press the up/down arrow or any number on your keyboard to change the layer your new neurons are made in. Just like in our brains, information is processed through different layers of neurons. For example, one layer in the visual part of our brain processes lines, the next layer combines lines together, the next layer figures out the shape of those lines, such as a box. In this visualization, information comes into layer 0 neurons by clicking/pressing the “Generate Input” box. This causes the voltage within these neurons to increase, making the neurons go from white (-70 mV) to a solid color as the voltage rises. Once the voltage is high enough (-50 mV), information is relayed to the next layer through an action potential (yellow balls). You can watch as information moves through a brain you create! More than just that, we can see here how our brains learn. You may notice that some of the connecting lines between neurons, called axons, get thicker or thinner as information flows through our mini-brain. This is called neuroplasticity. Neurons that fire with one another strengthen their connections (they become thicker) while neurons which fire out of sync with one another will get weaker (thinner) connections. When we learn something new, it is these connections between neurons which are getting thicker and helping us remember new information! Before you generate the input, can you guess which neurons will learn the fastest and wire together?
You’ll have to refresh the page to make a new brain.

Here is my latest version from May 2019! Now you can make your own connections between neurons (axons) rather than have them be randomly generated. Simply click the two neurons you want to connect and Voilà, an axon is made. The first neuron you click will be pre-synaptic and the second will be post-synaptic.

And because of this, you are no longer constrained to my layers. Now there is only an input layer and hidden layers (to use computer-science neural-network terms). However, this also means you can’t see the membrane potential of neurons. The old demo is still below if needed.




Old version: Click to make a new neuron. Use the arrows or number keys to change the layer. Once your brain is built, click “Generate Input” to get things going. Click individual neurons to see more about them.


Let me know in the comments if there are other features you would like to see in the program or of any errors you find.

All the code is available below. I wrote this in Processing, an open-source language built on Java. You can easily play around with the code here.
This neural network tool is inspired by Neuronify, which is much more in-depth than mine (though it requires a download).

Interested in programming and neural networks? Start here with the Coding Train. 

// An animated drawing of a Neural Network
// Based on Daniel Shiffman's example in "The Nature of Code", http://natureofcode.com

// Modified by Blake Porter
// www.blakeporterneuro.com
// Creates a network of neurons using a specified number of layers of neurons, inputs, and axons
// Connection synapse weights are created randomly and adjust based on activity levels, showing long term potentiation for active neural pairs and long term depression for inactive synapses

// Note - this code works well online in a browser with processing.js. It does not always work nicely natively in Processing. 
// Vertical correction subtracted from mouseY in every hit test and when placing
// a neuron; the author found browser mouse coordinates to be off by 30 px.
int yoffset = 30;

// "Generate Input" button
int rectX, rectY;            // top-left corner, assigned in setup()
int rectSize = 140;          // button width; draw() renders it rectSize-40 tall
color rectColor;             // fill while input generation is off
color rectHighlight;         // fill while input generation is on
boolean rectOver = false;    // set by update() when a click lands on the button
int inputFont = 28;          // text size of the button caption

// Text size of the "Layer: N" read-out
int layerFont = 50;

// Layer up/down buttons
int rect2X, rect2YU, rect2YD; // shared x, y of the up button, y of the down button
int rect2Size = 75;           // side length of each square button
color rect2Color;
color rect2UpHighlight;
color rect2DownHighlight;
boolean rectUpOver = false;   // last click landed on the up button
boolean rectDownOver = false; // last click landed on the down button
int layer2Font = rect2Size;   // arrow glyphs are drawn as tall as the button
boolean clickedGen = false;   // toggle state of the "Generate Input" button


int neuronID = 0;        // id given to the next neuron; equals its index in network.neurons
int axons = 1;           // the number of axons a neuron has
int currLayer = 0;       // layer assigned to newly created neurons (0-9)
float randWMax = 0.001;  // upper bound for the random initial synapse weight
String input;
int baseDelay = 30;      // minimum ms between two input pulses to layer 0
int selectedNeuron;      // id of the neuron hit by the last click
// Id of the neuron whose label was shown by the previous click. Starts at -1
// (no neuron) so that the very first click on neuron 0 is not mistaken for a
// repeated click, which would keep its label hidden.
int prevSelNeuron = -1;
boolean logSelNeuron = false; // draw the membrane-potential label for selectedNeuron
boolean overNothing = true;   // last click hit neither a neuron nor a button
boolean wasGen = false;       // set when input generation is switched off
boolean isGen = false;        // input generation currently running

// initialize 
int startRand = 0;
int endRand = 0;
int lastMsec = 0;  // millis() of the last input pulse
int currMsec = 0;  // millis() sampled in the current frame


Network network;

// One-time sketch initialisation: canvas size, button colours and geometry,
// and the (initially empty) network.
void setup() {
  size(1200, 600);

  // Colours
  rectColor = color(0);
  rectHighlight = color(5, 178, 255);
  rect2Color = color(0);
  rect2UpHighlight = color(128);
  rect2DownHighlight = color(128);

  // Layer buttons sit in the top-left corner, stacked with a 15 px gap
  rect2X = 10;
  rect2YU = 10;
  rect2YD = rect2YU + rect2Size + 15;

  // "Generate Input" button is centred horizontally and top-aligned with the up button
  rectX = width/2 - rectSize/2;
  rectY = rect2YU;

  // Create the Network object
  network = new Network(width/2, height/2);
}

// Per-frame rendering and simulation step: draws the buttons and layer
// read-out, handles number-key layer selection, feeds input to layer 0 while
// generation is on, advances the network, and draws the membrane-potential
// label of the selected neuron.
void draw() {
  background(255);

  // "Generate Input" button. The fill is chosen explicitly before drawing
  // instead of relying on whatever fill the previous frame happened to leave
  // behind (which rendered the first frame as white text on a white button).
  pushStyle();
  stroke(0);
  if (isGen) {
    fill(rectHighlight);
  } else {
    fill(rectColor);
  }
  rect(rectX, rectY, rectSize, rectSize-40);
  String APprompt = "Generate Input";
  textSize(inputFont);
  textAlign(CENTER);
  fill(255);
  text(APprompt, rectX+4, rectY+20, rectSize, rectSize);
  popStyle();

  // Layer up/down buttons
  pushStyle();
  fill(rect2Color);
  rect(rect2X, rect2YU, rect2Size, rect2Size);
  rect(rect2X, rect2YD, rect2Size, rect2Size);
  fill(255);
  String arrowUp = "↑";
  String arrowDown = "↓";

  textSize(layer2Font);
  text(arrowUp, rect2X+(rect2Size/4), rect2Size-12);
  text(arrowDown, rect2X+(rect2Size/4), rect2YD+rect2Size-20);
  popStyle();

  // Current layer info
  textSize(layerFont);
  String layerText = "Layer: " + str(currLayer);
  fill(0, 0, 0);
  text(layerText, rect2Size+rect2X+10, layerFont);

  // Number keys 0-9 select the layer directly ('0' has character code 48).
  // Any other key falls outside the 0..9 range and is ignored.
  if (keyPressed) {
    int keyVal = key - 48;
    if (keyVal >= 0 && keyVal <= 9) {
      currLayer = keyVal;
    }
  }

  // The hover/selection flags (rectOver, logSelNeuron, overNothing, ...) are
  // refreshed only inside mouseClicked() via update(), not every frame; the
  // author did this to stop the UI from flashing.

  // While generation is on, push a pulse into every layer-0 neuron at most
  // once per baseDelay milliseconds.
  if (isGen) {
    currMsec = millis();
    if (currMsec > lastMsec + baseDelay) {
      for (int i = 0; i < network.neurons.size(); i++) {
        Neuron currN = network.neurons.get(i);
        if (currN.layer == 0) {
          currN.feedForward(50);
          lastMsec = millis();
        }
      }
    }
  }

  // Update and display the Network
  network.update();
  network.display();

  // Membrane-potential label next to the neuron picked by the last click
  if (logSelNeuron && !overNothing) {
    Neuron sel = network.neurons.get(selectedNeuron);
    pushStyle();
    textSize(12);
    fill(255);
    rect(sel.location.x+sel.r_base-1, sel.location.y-12, 181, 15);
    fill(0);
    // Rescale the internal sum (RMP..spkT) to the -70..-50 mV display range
    float currMP = map(sel.sum, sel.RMP, sel.spkT, -70, -50);
    currMP = int(currMP);
    String MP = "Membrane Potential: " + currMP + "mV";
    text(MP, sel.location.x+sel.r_base+1, sel.location.y);
    popStyle();
  }
} // end draw

void mouseClicked () { // change to mouseClicked for browsers

  //float x = mouseX + 0 - width/2;
  //float y = mouseY -yoffset - height/2;

  float x = mouseX + 0;
  float y = mouseY -yoffset;

  update(mouseX, mouseY-yoffset);

  //    println("__________");
  //  println("rectUpOver");
  //  println(rectUpOver);
  //  println("rectDownOver");
  //  println(rectDownOver);
  //  println("rectOver");
  //  println(rectOver);
  //  println("logSelNeuron");
  //  println(logSelNeuron);
  //  println("overNothing");
  //  println(overNothing);
  //  println("skipN");
  //  println(skipN);
  //println("isGen");
  //println(isGen);

  if (!overNothing) {
    if (rectUpOver) {
      currLayer = currLayer+1;
      currLayer = constrain(currLayer, 0, 9);
    }
    if (rectDownOver) {
      currLayer = currLayer -1;
      currLayer = constrain(currLayer, 0, 9);
    }
  //} else if (wasGen && !isGen) {
  //  wasGen = false;
  } else {
    Neuron n = new Neuron(x, y, neuronID, currLayer);
    neuronID = neuronID + 1;
    network.addNeuron(n);
    if (network.neurons.size() > 1 && currLayer != 0) {
      int[] possPost = {}; 
      for (int i = 0; i <= network.neurons.size()-1; i++) {
        Neuron postN = network.neurons.get(i);
        if (postN.layer == currLayer-1) {
          possPost = append(possPost, postN.neuronNum);
        }
      }


      if (possPost.length > 0) {
        float randPostf = random(0, possPost.length-1);
        int randPosti = round(randPostf);
        int randPostConnect = possPost[randPosti];

        Neuron randN = network.neurons.get(randPostConnect); 
        int newConnection = randN.neuronNum;
        n.addConnection(newConnection, random(0, randWMax));
        randN.addPostSyn(neuronID-1);

        float randPost2 = random(0, 1);
        if (randPost2 <= 0.5) {
          randPostf = random(0, possPost.length-1);
          randPosti = round(randPostf);
          randPostConnect = possPost[randPosti];

          Neuron randN2 = network.neurons.get(randPostConnect); 
          newConnection = randN2.neuronNum;
          n.addConnection(newConnection, random(0, randWMax));
          randN2.addPostSyn(neuronID-1);

          float randPost3 = random(0, 1);
          if (randPost3 <= 0.2) {
            randPostf = random(0, possPost.length-1);
            randPosti = round(randPostf);
            randPostConnect = possPost[randPosti];

            Neuron randN3 = network.neurons.get(randPostConnect); 
            newConnection = randN3.neuronNum;
            n.addConnection(newConnection, random(0, randWMax));
            randN3.addPostSyn(neuronID-1);
          }
        }
      }
    }
  }
}

// Recomputes the click-state flags for the current mouse position: neuron
// selection, the "Generate Input" toggle and the two layer buttons.
// The x and y parameters are not used; hit testing reads mouseX/mouseY
// through overRect() and overNeuron().
void update(int x, int y) {
  overNothing = true;
  logSelNeuron = false;
  rectOver = false;
  overNeuron();

  // Hit-test exactly the rectangle draw() renders (rectSize wide,
  // rectSize-40 tall). Testing the full rectSize height made the invisible
  // 40 px strip below the button toggle the input as well.
  if (overRect(rectX, rectY, rectSize, rectSize-40)) {
    overNothing = false;
    rectOver = true;
    if (!clickedGen) {
      // First click: start generating input
      isGen = true;
      clickedGen = true;
    } else {
      // Second click: stop generating input
      isGen = false;
      wasGen = true;
      clickedGen = false;
    }
  }

  rectUpOver = overRect(rect2X, rect2YU, rect2Size, rect2Size);
  if (rectUpOver) {
    overNothing = false;
  }

  rectDownOver = overRect(rect2X, rect2YD, rect2Size, rect2Size);
  if (rectDownOver) {
    overNothing = false;
  }
}


// Returns true when the mouse (with mouseY corrected by yoffset) lies inside
// the rectangle whose top-left corner is (x, y), edges included.
boolean overRect(int x, int y, int width, int height) {
  int correctedY = mouseY - yoffset;
  boolean insideX = mouseX >= x && mouseX <= x + width;
  boolean insideY = correctedY >= y && correctedY <= y + height;
  return insideX && insideY;
}

// Checks whether the mouse is over any neuron and updates the selection
// state. A neuron's hit area is the square extending r_base in each direction
// from its location. Clicking a neuron shows its label; clicking the same
// neuron again hides it.
void overNeuron() {
  float my = mouseY - yoffset;
  for (Neuron curr : network.neurons) {
    boolean insideX = mouseX >= curr.location.x-curr.r_base && mouseX <= curr.location.x+curr.r_base;
    boolean insideY = my >= curr.location.y-curr.r_base && my <= curr.location.y+curr.r_base;
    if (!(insideX && insideY)) {
      continue;
    }

    overNothing = false;
    selectedNeuron = curr.neuronNum;
    if (selectedNeuron == prevSelNeuron) {
      // Repeated click on the same neuron: hide the label and forget the
      // neuron, so the next click on it shows the label again. Without the
      // reset the label stayed hidden until a different neuron was clicked.
      logSelNeuron = false;
      prevSelNeuron = -1;
    } else {
      logSelNeuron = true;
      prevSelNeuron = selectedNeuron;
    }
  }
}

// One action potential travelling from the neuron that fired to a receiving
// neuron.
class ActionPotentials {
  PVector location;   // current position; starts at the sender's position
  PVector receiver;   // position the action potential is heading to
  int recNum;         // id of the receiving neuron
  float weight;       // synapse weight captured when the action potential was created
  int lerpCounter;    // stored from the constructor argument

  // Both positions are copied into new vectors, so this object does not share
  // PVector instances with the caller.
  ActionPotentials(PVector loc, PVector rec, int recNumber, float w, int LC) {
    this.recNum = recNumber;
    this.weight = w;
    this.lerpCounter = LC;
    this.receiver = new PVector(rec.x, rec.y);
    this.location = new PVector(loc.x, loc.y);
  }
}

// --- Input strength ---
float normInput = 200;  // input delivered by an arriving action potential, multiplied by its weight
int baseCurrent = 150;  // input given to each neuron by Network.baseInput()

// --- Membrane leak ---
int leak = 6;           // subtracted from every neuron's sum on a slow tick (20% of it on other frames)

// --- Synaptic plasticity ---
int maxW = 5;           // upper bound for a synapse weight
float LTP = 1.1;        // added to a weight on potentiation
float LTD = 0.25;       // subtracted from a weight on depression

// --- Slow-tick timing (milliseconds) ---
int updateTime = 500;   // minimum interval between two slow ticks
int currUpdate = 0;     // millis() sampled at the start of Network.update()
int prevUpdate = 0;     // millis() recorded at the last slow tick

// Holds every neuron of the sketch and advances/draws them each frame.
class Network {
  // The Network has a list of neurons; a neuron's id is its index in this list
  ArrayList<Neuron> neurons;

  PVector location;

  Network(float x, float y) {
    location = new PVector(x, y);
    neurons = new ArrayList<Neuron>();
  }

  // We can add a Neuron
  void addNeuron(Neuron n) {
    neurons.add(n);
  }

  // Sends baseCurrent to the first `inputs` neurons of the list.
  // The count is capped at the list size so a too-large value cannot index
  // past the end.
  void baseInput(int inputs) {
    int count = min(inputs, neurons.size());
    for (int i = 0; i < count; i++) {
      neurons.get(i).feedForward(baseCurrent);
    }
  }

  // Advances the simulation by one frame. Every frame each neuron leaks and
  // its action potentials move; once per updateTime ms (a "slow tick") the
  // full leak is applied, synapse weights are adjusted and spike counts reset.
  // Nothing happens while the network has fewer than two neurons.
  void update() {
    if (neurons.size() <= 1) {
      return;
    }

    currUpdate = millis();
    boolean slowTick = currUpdate > prevUpdate + updateTime;

    float leakStep = leak;
    if (!slowTick) {
      leakStep = leak*0.2;
    }

    for (Neuron n : neurons) {
      n.sum = constrain(n.sum - leakStep, n.RMP, n.spkT);
      moveActionPotentials(n);
      if (slowTick) {
        adjustWeights(n);
        n.spkCount = 0;
      }
    }

    // Record the tick once per slow tick, whether or not any neuron has
    // connections. Previously this happened only inside the connections loop,
    // so a network without connections ran the slow tick on every frame.
    if (slowTick) {
      prevUpdate = millis();
    }
  }

  // Moves each action potential of n a tenth of the remaining way towards its
  // receiver and delivers those that have arrived.
  void moveActionPotentials(Neuron n) {
    // Iterate backwards because delivered action potentials are removed
    for (int i = n.APs.size()-1; i >= 0; i--) {
      ActionPotentials ap = n.APs.get(i);

      // Arrived = within APr of the receiver on both axes. Using the absolute
      // distance makes this independent of travel direction; the earlier
      // one-sided ">=" test was already true on the first frame for receivers
      // lying to the left of and above the sender.
      boolean arrived = abs(ap.receiver.x - ap.location.x) <= n.APr
        && abs(ap.receiver.y - ap.location.y) <= n.APr;

      if (arrived) {
        Neuron recN = neurons.get(ap.recNum);
        // Weights below 1 still deliver one full normInput
        float currW = constrain(ap.weight, 1, maxW);
        recN.feedForward(normInput*currW);
        n.APs.remove(i);
      } else {
        ap.location.x = lerp(ap.location.x, ap.receiver.x, 0.1);
        ap.location.y = lerp(ap.location.y, ap.receiver.y, 0.1);
      }
    }
  }

  // Strengthens (by LTP) or weakens (by LTD) each incoming connection of n and
  // keeps the weight within 0..maxW.
  void adjustWeights(Neuron n) {
    for (int i = 0; i < n.connections.length; i++) {
      Neuron preN = neurons.get(n.connections[i]);

      boolean potentiate;
      if (preN.layer == 0) {
        // Input-layer source: n only has to have fired since the last slow tick
        potentiate = n.spkCount > 0;
      } else {
        // NOTE(review): spkCount is reset neuron by neuron within this same
        // pass, so preN.spkCount is already 0 here whenever preN was visited
        // earlier in the list - confirm whether that is intended.
        potentiate = n.spkCount >= preN.spkCount && n.spkCount > 0;
      }

      float changed;
      if (potentiate) {
        changed = n.weights[i] + LTP;
      } else {
        changed = n.weights[i] - LTD;
      }
      n.weights[i] = constrain(changed, 0, maxW);
    }
  }

  // Draw everything: axons first, then neuron bodies, then action potentials
  // on top.
  void display() {
    for (Neuron n : neurons) {
      n.displayAx();
    }
    for (Neuron n : neurons) {
      n.displayN();
    }
    for (Neuron n : neurons) {
      n.displayAP();
    }
  }
}

// An animated drawing of a Neural Network
// Based on Daniel Shiffman's example in "The Nature of Code", http://natureofcode.com

// Modified by Blake Porter
// www.blakeporterneuro.com
// Creates a network of neurons using a specified number of layers of neurons, inputs, and axons
// Connection synapse weights are created randomly and adjust based on activity levels, showing long term potentiation for active neural pairs and long term depression for inactive synapses

// A single neuron: accumulates input, fires action potentials to its
// postsynaptic neurons when the threshold is reached, and draws itself, its
// incoming axons and its travelling action potentials.
class Neuron {
  int neuronNum;        // id; also this neuron's index in network.neurons
  PVector location;     // Neuron has a start location, x and y
  PVector centerLoc;    // Neuron has a central location, x and y, in order to draw its axon from
  int prevSpkCount = 0; // how many spikes did it have
  int spkCount = 0;     // how many times it has fired since we last checked?
  float sum = 0;        // the intracellular voltage value, start at 0
  int layer;
  boolean justFired = false;

  int refract = 150;    // refractory period in ms
  int lastAP = 0;       // millis() of the last spike
  int RMP = 0;          // lower bound of sum (displayed as -70 mV)
  int spkT = 1000;      // firing threshold and upper bound of sum (displayed as -50 mV)

  // Incoming connections: ids of the presynaptic neurons and, at the same
  // index, the weight of each connection
  int[] connections = {};
  float[] weights = {};
  // Outgoing targets: ids of the neurons this neuron sends action potentials to
  int[] postSyn = {};

  // Action potentials currently travelling away from this neuron
  ArrayList<ActionPotentials> APs = new ArrayList<ActionPotentials>();

  // The Neuron's size can be animated
  float r_base = 20;  // how big it is at the start
  float r_pop = 25;   // how big it grows to when it fires
  float r = r_base;   // for reference

  float APr = r_base*0.5; // drawn size of an action potential

  // Center of triangle of the neuron
  float centerX;
  float centerY;

  Neuron(float x, float y, int neuronID, int currLayer) {
    location = new PVector(x, y);  // get start position
    centerLoc = new PVector(location.x, location.y); //calculate center
    layer = currLayer;
    neuronNum = neuronID;
  }

  // Add an incoming connection from neuron c with initial weight w
  void addConnection(int c, float w) {
    connections = append(connections, c);
    weights = append(weights, w);
  }

  // Register neuron x as a target of this neuron's action potentials
  void addPostSyn(int x) {
    postSyn = append(postSyn, x);
  }

  // Receive an input: accumulate it and, when the threshold spkT is reached,
  // either fire or (during the refractory period) hold just below threshold.
  void feedForward(float input) {
    sum = constrain(sum + input, RMP, spkT);
    if (sum < spkT) {
      return;
    }

    if ((millis() - lastAP) > refract) {
      fire();
      justFired = true;
      sum = 0;    // Reset the sum to 0 if it fires
      spkCount++; // add a spike to the spike count
      lastAP = millis();

      // Launch one action potential towards every postsynaptic neuron
      for (int i = 0; i < postSyn.length; i++) {
        Neuron recN = network.neurons.get(postSyn[i]);

        // The weight is stored on the receiver, at the index of the connection
        // that names this neuron. It is looked up afresh for each target so a
        // value found for an earlier target is never carried over.
        float apW = 0.0;
        for (int j = 0; j < recN.connections.length; j++) {
          if (recN.connections[j] == neuronNum) {
            apW = recN.weights[j];
          }
        }
        APs.add(new ActionPotentials(location, recN.location, recN.neuronNum, apW, 1));
      }
    } else {
      // Still refractory (including the exact boundary, which previously
      // matched neither branch): hold just below threshold
      sum = 900;
    }
  }


  // The Neuron fires
  void fire() {
    r = r_pop;   // It suddenly is bigger
  }


  // Draws the neuron body. The fill fades from white to a layer-specific
  // colour as sum rises from 0 to spkT.
  void displayN() {
    pushStyle();
    stroke(0);
    strokeWeight(1);

    color fromC = color(255, 255, 255);

    // Layer colour: blend from startC to endC by layer/9
    color startC = color(11, 107, 191);
    color endC = color(250, 218, 94);
    float layerF = layer;
    color toC = lerpColor(startC, endC, layerF/9);

    float memP = map(sum, 0, spkT, 0, 1);
    color neuronC = lerpColor(fromC, toC, memP);
    fill(neuronC);

    ellipse(location.x, location.y, r, r);
    r = lerp(r, r_base, 0.1); // ease back to the resting size after a spike
    popStyle();
  } // display N

  // Draws one line per incoming connection, from this neuron to the
  // presynaptic neuron, with the connection weight as line thickness.
  void displayAx() {
    pushStyle();
    stroke(0, 0, 0);
    for (int i = 0; i < connections.length; i++) {
      Neuron pre = network.neurons.get(connections[i]);
      strokeWeight(weights[i]);
      line(centerLoc.x, centerLoc.y, pre.centerLoc.x, pre.centerLoc.y);
    }
    popStyle();
  } // display axon

  // Draws every travelling action potential as a yellow dot
  void displayAP() {
    pushStyle();
    for (int i = 0; i < APs.size(); i++) {
      ActionPotentials currAP = APs.get(i);

      stroke(100);
      strokeWeight(1);
      color apC = color(255, 255, 0);

      fill(apC);
      ellipse(currAP.location.x, currAP.location.y, APr, APr); // draw APs as X% the neuron size
    }
    popStyle();
  } // display ap
}

Creative Commons License
This work by Blake Porter is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

Leave a Reply

This site uses Akismet to reduce spam. Learn how your comment data is processed.