Martin Beierling-Mutz

Use Machine Learning to filter messages in the browser

We'll use an Artificial Neural Network to classify messages as "spam" or "no spam".

The goal is to clean up this mess:

Twitch Chat Mess

This blog post was also published on dev.to

The Problem

“Should this message be hidden or shown?”

All solutions start with a problem statement. In our case, we have a text message and want to classify it as either "spam" or "no spam".

The classic application for this problem would be email, but I chose to use the Twitch chat of big channels. These tend to be really spammy and thus hard to read, while data is easy to collect.

What's Twitch?

Twitch.tv is an online platform for streaming video to thousands of like-minded people on the internet. It is most commonly used to stream a video game together with the person playing it. The chat is the main channel of interaction with the streamer, but in channels with a lot of viewers it can become hectic and unreadable, flooded by low-value messages.

The Solution

Filtering text messages into spam and no-spam is a binary classification problem. There are lots of ways to solve this, and I chose one: Artificial Neural Networks (ANNs).

More specifically, it will be a feedforward neural net with backpropagation. This is a very common and basic setup, where our nodes won't form cycles (hence feedforward) and all nodes will be notified about how well they performed (hence backpropagation).

If you want to jump directly to code, here's the repo. Please take all of it with a grain of salt, as I created this 2 years ago as part of my Machine Learning Nanodegree at Udacity. Lots has changed since then, like the release of TensorFlow.js.

embiem/Better-Twitch-Chat

There's also a live version of the Web App, which lets you connect to a Twitch channel to either filter it or train your own model based on liking/disliking messages.

Done having a quick look? Great, let's continue...

What's an ANN?

AI, Machine Learning, Deep Learning and similar buzzwords are hot nowadays. So let's make sure the terminology is clear: Deep Learning is a subset of Machine Learning, and AI is a research field, which may use Machine Learning. I always like to give Boston Dynamics' Atlas robot as an example: It's a product of AI and robotics research, but doesn't actually use any Machine Learning solutions.

With ANN, we describe the general concept of modeling a digital system after biological neural networks. Our brain is built of lots of interconnected neurons, which create neural networks. We know how they work: neurons receive inputs and may send outputs if the input is above a threshold.

An artificial neural network consists of nodes imitating the behavior of neurons. Like Lego bricks, we assemble them in specific ways to form something cool, like a Millennium Falcon.

Lego Millennium Falcon

Jokes aside, it is astonishing to see how carefully built neural networks revolutionize industries today.

In the end, it all comes down to nodes, connected in a specific way and with some logic attached that defines how input is used to create an output.

How does an ANN work?

Let's take our feedforward neural net with backpropagation as an example. It should classify incoming text as either "spam" or "no spam". We will input our text message to the first nodes in our network. These nodes will do some calculations based on the received input and the nodes' internal state. The results are then sent to the next nodes. This happens until we reach the last node in our network, which will be responsible for classifying the input text message as "spam" or "no spam". During training, we know whether a message is spam or not and will give the network a 👍 or 👎, based on how well it did. This feedback is propagated back through all nodes of the network and every node will tune its internal state a bit.
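To make the "calculations based on the received input and the nodes' internal state" a bit more tangible, here is a minimal sketch of the forward pass only. It is not how brain.js is implemented. The weights, biases and layer sizes are made up for illustration, and in a real net they would be the "internal state" that training tunes.

```javascript
// Sigmoid squashes any number into the range (0, 1).
const sigmoid = x => 1 / (1 + Math.exp(-x));

// One layer: every node computes a weighted sum of all inputs plus a bias,
// then applies the activation function.
function forwardLayer(inputs, weights, biases) {
  return weights.map((nodeWeights, i) => {
    const sum = nodeWeights.reduce((acc, w, j) => acc + w * inputs[j], biases[i]);
    return sigmoid(sum);
  });
}

// Hypothetical, hand-picked state for a net with 3 inputs,
// 2 hidden nodes and 1 output node.
const hiddenWeights = [[0.5, -0.2, 0.1], [-0.3, 0.8, 0.4]];
const hiddenBiases = [0, 0.1];
const outputWeights = [[1.2, -0.7]];
const outputBiases = [0.05];

// Feed the input through the hidden layer, then through the output layer.
function forward(inputs) {
  const hidden = forwardLayer(inputs, hiddenWeights, hiddenBiases);
  return forwardLayer(hidden, outputWeights, outputBiases)[0];
}

const score = forward([1, 0, 1]);
console.log(score); // a single value between 0 and 1
```

Backpropagation would then compare this score with the known label and nudge every weight and bias a little in the direction that reduces the error.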

There's a cool playground tool which lets you train a neural net right in your browser and play around with the different parts: playground.tensorflow.org. Don't worry too much about what a learning rate or TanH activation is yet. Just play around with it a bit. Have fun 🤓

Furthermore, while I was writing this article, @Petro Liashchynskyi published an article which explains ANNs and their concepts on a more technical level.

I'll also have a list of awesome ML resources at the end of this article.

Data

The most important dependency of any Machine Learning solution is data. The more and better data you have, the better your model will perform.

Data Collection

This is often one of the most difficult tasks. Just imagine, labeling millions of images with "apple", "bike", "human", "dog", ...

In (supervised) Machine Learning, the machine learns by example. So we need to give the machine lots of examples. The more complex the task (like classifying objects in an image), the more examples we need.

For our problem, we need lots of text messages, each labeled as "spam" or "no spam". This is one of the reasons why I chose to filter messages of a Twitch channel: Most messages are short and collecting data is relatively easy.

In the Web App, there's one view dedicated to collecting data. You join a Twitch channel and all messages are listed in real-time, much like the built-in Twitch chat. Additionally, there's a like/dislike button for each message to indicate whether a message is spam or not. All liked/disliked messages are sent to a database.

Data Preprocessing

One of the main differences between machines and humans: Machines are very efficient at reading and computing zeroes and ones, while we humans are incredibly good at understanding concepts. We see a bunch of letters, read them and understand the information they carry. Machines wouldn't even know what letters are.

That's why we humans build character encodings like UTF-8, which gives the machine a way to structure 0s and 1s to form a concept of letters. Then we can do things like putting multiple characters into an array to build sentences or save articles like this one in a database.
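As a small illustration of that idea, the standard TextEncoder API (available in browsers and as a global in recent Node.js versions) shows the bytes behind a short string. The example string is arbitrary.

```javascript
// Encode a string as UTF-8 bytes.
const bytes = new TextEncoder().encode('Hi');
console.log(Array.from(bytes)); // [ 72, 105 ]

// The same bytes written as zeroes and ones.
const bits = Array.from(bytes, b => b.toString(2).padStart(8, '0'));
console.log(bits); // [ '01001000', '01101001' ]
```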

To teach a machine anything about our world, we need to keep these things in mind. In this project we deal with text, so how do we input this text to a ML model?

There's a popular solution called Bag of Words (BOW), which takes text as input and outputs a bunch of zeroes and ones. Perfect!
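Before we look at the library, here is a rough, library-free sketch of the Bag of Words idea. This is not mimir's implementation. The helper names (tokenize, buildVocabulary, bagOfWords) and the sample messages are made up, and the tokenizer is deliberately naive.

```javascript
// Split a message into lowercase word tokens.
const tokenize = text => text.toLowerCase().split(/\s+/).filter(Boolean);

// Collect every distinct word across all messages, in order of first appearance.
function buildVocabulary(messages) {
  const words = [];
  for (const message of messages) {
    for (const word of tokenize(message)) {
      if (!words.includes(word)) words.push(word);
    }
  }
  return words;
}

// 1 if the message contains the vocabulary word at that index, 0 otherwise.
function bagOfWords(message, vocabulary) {
  const tokens = new Set(tokenize(message));
  return vocabulary.map(word => (tokens.has(word) ? 1 : 0));
}

const vocabulary = buildVocabulary(['hello chat', 'LUL LUL LUL', 'hello streamer']);
console.log(vocabulary); // [ 'hello', 'chat', 'lul', 'streamer' ]
console.log(bagOfWords('hello lul', vocabulary)); // [ 1, 0, 1, 0 ]
```

Note that the output always has the same length as the vocabulary, no matter how long the message is. Word order is lost, which is where the "bag" in the name comes from.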

For this project, you can find the data preprocessing in one small file at node/dataPrep.js. Let's also get some code into this article.

First we need to take our dataset and put all messages into an array and all labels (show = "no spam" & hide = "spam") into another array:

const dataFlat = [];
const dataLabels = [];

for (let key in data) {
  dataFlat.push(data[key].message);
  dataLabels.push(data[key].liked ? 'show' : 'hide');
}

Then we create a dictionary of words, using mimir:

const dictData = mimir.dict(dataFlat);
const dictLabels = mimir.dict(dataLabels);

And finally, we optimize our data a bit by removing any entries that only appeared once in the dataset:

for (let key in dictData.dict) {
  if (dictData.dict[key] < 2) {
    delete dictData.dict[key];
    _.remove(dictData.words, w => w === key);
    _.remove(dataFlat, w => w === key);
  }
}

This is optional, but will reduce the complexity that our model needs to handle. The bigger the BOW dictionary gets, the more complex the input space becomes. This problem is known as the Curse of Dimensionality.

Data Exploration & Visualization

As data is very important to the success of your ML solution, knowing your data is part of it. Without knowing how the data is structured and finding certain characteristics, it will be hard to even start developing a model. This is especially important for datasets that you didn't create.

For our project, we know the data very well as we actually created the dataset on our own by liking/disliking messages. So one interesting visualization is the word occurrence. Here's a graph showing which words occurred the most in messages labeled as "show"/"no spam":

Word Occurrence Graph

Pretty common words, right? What's also interesting is how rarely these words appear in messages labeled as "spam": they occur very often in "no spam" messages but almost never in "spam" messages. Therefore, they have a high impact on how the ML model classifies messages that include these words.
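Counting word occurrences per label doesn't need any special tooling. Here's a minimal sketch, assuming entries shaped like the dataset from the preprocessing step (a message plus a liked flag). The sample messages are invented and are not taken from the real dataset.

```javascript
// Hypothetical sample in the same shape as the dataset used earlier.
const samples = [
  { message: 'what game is this', liked: true },
  { message: 'LUL LUL LUL', liked: false },
  { message: 'is this the new patch', liked: true },
  { message: 'LUL', liked: false }
];

// Count how often each word occurs across the given entries.
function countWords(entries) {
  const counts = Object.create(null); // no inherited keys like "constructor"
  for (const { message } of entries) {
    for (const word of message.toLowerCase().split(/\s+/).filter(Boolean)) {
      counts[word] = (counts[word] || 0) + 1;
    }
  }
  return counts;
}

const showCounts = countWords(samples.filter(s => s.liked));
const hideCounts = countWords(samples.filter(s => !s.liked));

// Words sorted by how often they occur in "show" messages,
// next to their count in "hide" messages.
const table = Object.keys(showCounts)
  .sort((a, b) => showCounts[b] - showCounts[a])
  .map(word => ({ word, show: showCounts[word], hide: hideCounts[word] || 0 }));

console.log(table);
```

Words with a high count in one column and a count near zero in the other are the ones that separate the two classes most clearly.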

This info could also come in handy when trying to reduce dimensionality. Certain approaches try to find high entropy features and combine these. But that's a story for another article.

Even if you don't know Python, there's a very good notebook on Kaggle about data exploration: Comprehensive Data Exploration with Python.

For JavaScript devs, you might want to have a look at Observable. Just have a look at this data exploration JS notebook of the Titanic dataset. It's pretty cool!

Model

Here comes the easy part!

Once you know your data and what you want to achieve, it's time to create the model.

Build your model

I wasn't joking. In my opinion, data exploration actually takes more of your precious brain juice. Building and training your model "just" takes time and iteration.

You could look at research papers or a list of proven network architectures, but what fun would that be?

So let's build our model:

const net = new brain.NeuralNetwork({
  hiddenLayers: layers,
  learningRate: lr
});

I said it's easy!

This doesn't mean that it can't get complicated. It just doesn't have to. Especially with tools like brain.js, tf.js, keras etc., achieving a well-performing model is possible with only a handful of lines of code. This, of course, always depends on the problem you want to solve.

Model Training

This part can actually get a little more complex. My plan was to train multiple networks with varying architectures and parameters. So I wrote node/modelFactory.js, which encapsulates building and training of the model. This allows me to train multiple models with varying layers and learning rate.

The learning rate defines how quickly the model will learn. There are good defaults, but some models may work better with a higher LR and others with a much lower LR. This popular graphic shows what happens if the LR is too high or too low:

LR too high or low
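You can reproduce the effect with a toy problem. The sketch below runs plain gradient descent on f(x) = x², which has nothing to do with our spam model. It only shows how the step size changes the outcome. The three learning rates are arbitrary picks.

```javascript
// Minimize f(x) = x^2, whose gradient is 2x, starting from x = 1.
function gradientDescent(learningRate, steps = 20) {
  let x = 1;
  for (let i = 0; i < steps; i++) {
    x = x - learningRate * 2 * x; // step against the gradient
  }
  return x;
}

console.log(gradientDescent(0.01)); // still far from 0: learning is slow
console.log(gradientDescent(0.3));  // very close to 0: converges quickly
console.log(gradientDescent(1.1));  // magnitude grows with every step: diverges
```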

Before you dive into the modelFactory.js code, here are the essential parts of model training.

First, you may need to transform your data, which is true for this project. We built dictionaries using mimir, now we need to get the actual Bag of Words. This will return arrays of zeroes and ones (e.g. [0, 1, 0, 0, 1]). Strictly speaking, this is a multi-hot rather than a one-hot encoding, as more than one entry can be 1. The important part is that these arrays are always of the same length. Every 1 corresponds to a certain word of the dictionary and means that the message that this BOW represents contains the corresponding word. Check out mimir's BOW code example.

const allData = [];
dataFlat.forEach((entry, idx) => {
  allData.push({
    input: mimir.bow(entry, dictData),
    output: mimir.bow(dataLabels[idx], dictLabels)
  });
});

Then, we use our current layer and lr parameters to build the model and train it with our training data.

// create the net
const net = new brain.NeuralNetwork({
  hiddenLayers: layers,
  learningRate: lr
});

// train the net
const trainResult = net.train(traindata, {
  errorThresh: 0.005,
  iterations: 10000,
  log: true,
  logPeriod: 100
});

// test the net
const testResult = testing(net, testdata, threshold);

// write net to file
fs.writeFileSync(
  `./out/nets/${fold + 1}_${netName}.json`,
  JSON.stringify(net.toJSON())
);

This will then save a trained model, which is ready to use.
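The snippet above uses traindata, testdata and fold without showing where they come from. The real logic lives in the repo. As a rough sketch, and purely as an assumption about how such folds could be produced, a k-fold split might look like this. kFoldSplits and the example data are hypothetical.

```javascript
// Split data into k folds. Each fold is used once as test data,
// while the remaining entries are used for training.
function kFoldSplits(data, k) {
  const splits = [];
  for (let fold = 0; fold < k; fold++) {
    const testdata = data.filter((_, idx) => idx % k === fold);
    const traindata = data.filter((_, idx) => idx % k !== fold);
    splits.push({ fold, traindata, testdata });
  }
  return splits;
}

// Numbers stand in for the { input, output } entries of allData.
const exampleData = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
const splits = kFoldSplits(exampleData, 5);

console.log(splits[0].testdata);  // [ 0, 5 ]
console.log(splits[0].traindata); // [ 1, 2, 3, 4, 6, 7, 8, 9 ]
```

In practice you would usually shuffle the data before splitting, so that every fold contains a similar mix of "show" and "hide" messages.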

Model Evaluation

Did you see the testing(net, testdata, threshold) call in the above code snippet? This will test how our model is performing after it was trained.

It calculates the precision, recall and finally the F1 score. This is a common score to use, and especially useful in our binary classification project.

The implementation is pretty straightforward:

function testing(net, testData, threshold) {
  let truePositives = 0;
  let trueNegatives = 0;
  let falsePositives = 0;
  let falseNegatives = 0;

  for (let i = 0; i < testData.length; i++) {
    const result = net.run(testData[i].input);
    // "show" ("no spam") is treated as the positive class
    const predictedShow = result.show >= threshold;
    const actualShow = testData[i].output.show === 1;

    if (predictedShow && actualShow) {
      truePositives++;
    } else if (predictedShow && !actualShow) {
      falsePositives++;
    } else if (!predictedShow && actualShow) {
      falseNegatives++;
    } else {
      trueNegatives++;
    }
  }

  // "|| 0" guards against NaN when a denominator is zero
  const precision = truePositives / (truePositives + falsePositives) || 0;
  const recall = truePositives / (truePositives + falseNegatives) || 0;
  const f1 = 2 * ((precision * recall) / (precision + recall)) || 0;
  return { precision, recall, f1 };
}

We take the trained net, some test data and a manually set threshold. The threshold is the cutoff at which we classify something as "spam" or "no spam". A reasonable value would be 0.8, which means that if the model is at least 80% sure a message is "no spam", then we will classify it as "no spam", otherwise as "spam".
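If the three scores are new to you, a worked example with numbers may help. The counts below are invented and are not results from the actual models.

```javascript
// Precision: how many of the messages we decided to show were really "no spam"?
// Recall: how many of the real "no spam" messages did we actually show?
// F1: the harmonic mean of both.
function scores({ truePositives, falsePositives, falseNegatives }) {
  const precision = truePositives / (truePositives + falsePositives);
  const recall = truePositives / (truePositives + falseNegatives);
  const f1 = (2 * precision * recall) / (precision + recall);
  return { precision, recall, f1 };
}

// Hypothetical counts from a test run.
const result = scores({ truePositives: 40, falsePositives: 10, falseNegatives: 20 });
console.log(result);
// precision = 40 / 50 = 0.8
// recall    = 40 / 60 ≈ 0.667
// f1        = 80 / 110 ≈ 0.727
```

Raising the threshold typically trades recall for precision: fewer messages are shown, but the ones that are shown are more likely to be worth reading.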

This score is then used to evaluate how well a certain model performs. You can see that I tested lots of different model architectures in node/hyperparamTuning.js.
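The idea behind such tuning can be sketched as a simple grid search. The candidate values below are made up, and trainAndScore is a stand-in that returns a fake score so the sketch runs on its own. In the real project this step builds, trains and tests a brain.js net as shown above.

```javascript
// Hypothetical candidate values for hidden layers and learning rate.
const layerOptions = [[10], [20], [10, 10]];
const learningRates = [0.1, 0.3, 0.6];

// Build every combination of layers and learning rate.
const candidates = [];
for (const layers of layerOptions) {
  for (const lr of learningRates) {
    candidates.push({ layers, lr });
  }
}

// Stand-in for "build, train and test one model". Returns a made-up F1 score.
const trainAndScore = ({ layers, lr }) =>
  1 / (1 + Math.abs(lr - 0.3) + layers.length);

// Score every candidate and keep the one with the highest F1.
const best = candidates
  .map(candidate => ({ ...candidate, f1: trainAndScore(candidate) }))
  .reduce((a, b) => (b.f1 > a.f1 ? b : a));

console.log(candidates.length); // 9
console.log(best);
```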

Use the Model to classify messages

This part is a breeze with brain.js. The model can be represented as a JSON string. The same is true for our word dictionary created with mimir. So all we need to do is load the model and the word dictionary as strings from some backend.

All the ML code in the React web app is located at src/api/NeuralNet.js. Essentially, to predict whether a new message is spam or not, we just need to call:

  predict(message) {
    const maxarg = array => {
      return array.indexOf(Math.max.apply(Math, array));
    };

    if (typeof message !== 'string' || message.length < 1) {
      console.warn(`Invalid message for prediction: ${message}`);
      return 0;
    }

    if (!this.net || !this.dict || typeof this.net.run !== 'function') {
      console.error('Cant predict because: net | dict', this.net, this.dict);
      return 0;
    }

    const test_bow_message = mimir.bow(message, this.dict);
    const prediction = this.net.run(test_bow_message);
    return maxarg(prediction);
  }

This will ...

  • create the BOW representation of the incoming message
  • call .run() on the neural net to get a prediction, which is an array like this: [0.2, 0.8]
  • return the index of the prediction array that has the highest value to classify the message as either "spam" or "no spam"
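The last step in isolation looks like this. The order of the labels array is an assumption for illustration. In the app it depends on how the label dictionary was built.

```javascript
// Index of the largest value in an array.
const maxarg = array => array.indexOf(Math.max(...array));

// Assumed label order, for illustration only.
const labels = ['hide', 'show'];

const prediction = [0.2, 0.8];
const index = maxarg(prediction);
console.log(index, labels[index]); // 1 show
```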

And that's it! We successfully trained a model offline using Node.js, saved the best-performing model as JSON and used it in the web app to classify new messages.

I left out a couple of things that are part of the web app. These include creation of the dataset, live data collection and in-browser training. You can find all these features in the repo and test them out in the web app. Let me know if you want another article going more in-depth on certain features.

More to read & watch about ML

Afterword

I am by no means an expert on this topic, but I dipped my toe in the Machine Learning waters and would like to encourage more (web)devs to try it as well. Hopefully, this article helped some of you to do just that!

Also thanks to Ben Halpern, who encouraged me to write this. I really like the dev.to community, you folks rock! ❤️