On-the-Fly Machine Learning in the Browser with TensorFlow.js

By Dave Bitter

10 min read

TensorFlow.js is an incredibly powerful JavaScript library for training and deploying machine learning models in the browser and Node.js. Let’s explore this library by building a teachable machine!


I first came across the teachable machine while researching TensorFlow.js. Quite frankly, the idea of this working seemed insane to me. Are you telling me I can run machine learning in the browser and even train the model on the fly? You sure can! Luckily, Google provides a Codelab where you build a simple version yourself, which I followed loosely.

The demo I'm going to build

I'm going to build a demo showing how you can apply transfer learning in the browser using TensorFlow.js in real time:

As you can see, I capture images for three classes: one neutral, one with my phone in my hand and one with a coffee in my hand. When I click "Train & Predict", the model trains in a fraction of a second. Finally, for each class, you can see in real time how confident it is that the class matches. Naturally, you can provide more diverse images so it gets better and better at classifying.

What is TensorFlow.js?

TensorFlow.js is a JavaScript library developed by Google that allows you to define, train, and run machine learning models entirely in the browser, using JavaScript and a high-level layers API. It is part of the TensorFlow ecosystem, which includes a range of tools for machine learning applications. The advantage of TensorFlow.js is that it allows machine learning models to be run in the browser (or in Node.js), making machine learning more accessible to JavaScript developers and allowing for real-time interaction with the user.

Some of the benefits are:

  • Privacy - you can both train and classify data on the machine of the user instead of sending it over to a server
  • Speed - you can directly classify data instead of having to send it over to a remote server
  • Sensor access - you can have access to the sensors of the user’s device like their camera and microphone
  • Deploy - you just have to deploy the web application and don’t have to mess with complex server-side setups for machine learning
  • Cost - last but not least, you dramatically cut down costs by just hosting a (static) website and using the user’s machine for the machine learning part

How does a teachable machine work?

A teachable machine basically takes an existing (or base) model and applies it to a similar but different domain. This is known as transfer learning. As humans, we do this all the time: we have a lifetime of experiences that we can use to recognize things we have never seen before. In my opinion, this example from the Codelab explains it best. This is a willow tree:

A picture of a willow tree

There is a chance you might have never seen this type of tree before. Now that I have shown you, find the willow tree in this image:

A picture of a willow tree in a forest

You already have neurons in your brain that know how to identify objects that look like trees and long straight lines. You can use that knowledge to quickly spot the willow tree in this image: it is a tree, it has long straight lines, and you have just learned that that is what a willow tree looks like.

Using the MobileNet base model

We need a base model that is already trained to classify objects, which we can then teach to recognize new things. Luckily, such a model exists. MobileNet is a popular model that performs image recognition on 1000 different types of objects. It was trained on a huge dataset called ImageNet, which has millions of labelled images. During training, it learned to spot common features across those 1000 objects. Many of these features, like lines, textures and shapes, can help identify new objects it hasn't seen before.

Let’s start building

For my demo, I bootstrapped a Next.js project, so the code examples show React code. Naturally, this is just my preference for reactive web applications and you can use any (or no) framework.

Loading TensorFlow.js and MobileNet

First, I added TensorFlow.js through NPM with the @tensorflow/tfjs package. I then created a small hook where I load MobileNet:

import { useEffect, useRef, useState } from 'react'

export const useMobileNet = ({ tf, numberOfClasses }) => {
  const model = useRef(null)
  const mobileNet = useRef(null)
  const [readyToTrainAndPredict, setReadyToTrainAndPredict] = useState(false)

  useEffect(() => {
    if (!tf) {
      return
    }

    async function loadMobileNetFeatureModel() {
      // ...
    }
    loadMobileNetFeatureModel().then(() => {
      // ...
    })
  }, [tf])

  return {
    model: model.current,
    mobileNet: mobileNet.current,
    readyToTrainAndPredict,
  }
}
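You might wonder where tf comes from. I pass it into the hook as an argument; a minimal sketch of how you could load it (illustrative, not the exact demo code) is to install @tensorflow/tfjs through npm and dynamically import it on the client, so it stays out of the Next.js server-side render:

import { useEffect, useState } from 'react'

// Illustrative helper: load TensorFlow.js in the browser only.
export const useTensorFlow = () => {
  const [tf, setTf] = useState(null)

  useEffect(() => {
    // The dynamic import keeps the (large) library out of the server bundle.
    import('@tensorflow/tfjs').then((tfjs) => setTf(tfjs))
  }, [])

  return tf
}

// Usage in a component (my demo uses three classes):
// const tf = useTensorFlow()
// const { model, mobileNet, readyToTrainAndPredict } = useMobileNet({ tf, numberOfClasses: 3 })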

The useMobileNet hook receives two arguments: tf (the TensorFlow.js module) and numberOfClasses, the number of different classes it should detect and categorise. Let’s dive into the loadMobileNetFeatureModel function:

async function loadMobileNetFeatureModel() {
  const URL =
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1'

  mobileNet.current = await tf.loadGraphModel(URL, {
    fromTFHub: true,
  })

  // Warm up the model by passing zeros through it once.
  tf.tidy(function () {
    mobileNet.current?.predict(tf.zeros([1, 224, 224, 3]))
  })
}

I load the model from TFHub, which is why I need to let TensorFlow.js know through fromTFHub: true. To warm the model up, I pass zeros with a few specific dimensions: 1 is the batch size, 224 is both the width and the height of the image and 3 is the number of colour channels (red, green and blue). Once the MobileNet feature model is loaded, I define the model head:

loadMobileNetFeatureModel().then(() => {
  model.current = tf.sequential()

  model.current.add(
    tf.layers.dense({
      inputShape: [1024],
      units: 128,
      activation: 'relu',
    })
  )

  model.current.add(
    tf.layers.dense({
      units: numberOfClasses,
      activation: 'softmax',
    })
  )

  model.current.summary()

  model.current.compile({
    optimizer: 'adam',
    loss: numberOfClasses === 2 ? 'binaryCrossentropy' : 'categoricalCrossentropy',
    metrics: ['accuracy'],
  })

  setReadyToTrainAndPredict(true)
})

Quite a few things are happening here. Let’s go over them. First, I'm creating a sequential model: a simple stack of layers that learns patterns in data to make predictions. Then, I'm adding the first dense layer, which takes the 1024 features that MobileNet extracts from an image as its input. It uses 128 "neurons" with the ReLU activation function to learn patterns in those features. Next, I'm creating the output layer where the model makes its predictions. It has one neuron per class and uses the softmax activation function so the outputs become a probability for each class. After that, I'm checking what the model looks like so far by printing a model summary. Finally, before the model starts learning, I'm setting up how it should learn with compile: the adam optimiser improves its guesses over time, the (binary or categorical) cross-entropy loss measures how wrong those guesses are, and accuracy is tracked as a metric.
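To get a feel for how small this model head is (and why training it takes a fraction of a second), here is the rough parameter count, assuming the three classes from my demo:

// Parameters of a dense layer = inputs × units + units (the biases)
// Layer 1: 1024 × 128 + 128 = 131,200 parameters
// Layer 2:  128 ×   3 +   3 =     387 parameters
// That's roughly 131,500 trainable parameters — tiny compared to MobileNet itself,
// which stays frozen and is only used to extract features.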

Using the user’s webcam to gather training data

Next, I want to be able to gather images to use in the demo. I could just let the user upload images, but let’s make it easier and allow them to snap images with their webcam. Once again, I create a small hook for this:

import { useEffect, useRef, useState } from 'react'

export const useWebcam = () => {
  const videoRef = useRef(null)
  const [videoPlaying, setVideoPlaying] = useState(false)

  const toggleWebcam = async () => {
    if (videoPlaying) {
      videoRef.current.srcObject?.getVideoTracks().forEach((track) => {
        track.stop()
      })
      setVideoPlaying(false)
      return
    }

    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      // getUserMedia constraints; the width and height belong inside the video property.
      const constraints = {
        video: {
          width: 640,
          height: 480,
        },
      }

      // Activate the webcam stream.
      navigator.mediaDevices.getUserMedia(constraints).then(function (stream) {
        videoRef.current.srcObject = stream
        videoRef.current.addEventListener('loadeddata', function () {
          setVideoPlaying(true)
        })
      })
    } else {
      console.warn('getUserMedia() is not supported by your browser')
    }
  }

  useEffect(() => {
    navigator.permissions.query({ name: 'camera' }).then(function (permissionStatus) {
      if (permissionStatus.state === 'granted') {
        toggleWebcam()
      }
    })
  }, [])

  return {
    toggleWebcam,
    videoRef,
    videoPlaying,
  }
}

With this hook, I can toggle the webcam on and off and have access to the video reference, so I can place it on an HTML video element (autoPlay and muted make sure the stream actually starts playing):

const { toggleWebcam, videoPlaying, videoRef } = useWebcam()

// ...

<video ref={videoRef} autoPlay muted playsInline />

Next, for each class, we add a button to the UI to capture data. Once the user clicks this button, I call the gatherDataForClass function:

const [capturedImages, setCapturedImages] = useState({})

const [isTraining, setIsTraining] = useState(false)
const [trainingData, setTrainingData] = useState([])

const gatherDataForClass = (classNumber) => {
  let imageFeatures = tf?.tidy(function () {
    let videoFrameAsTensor = tf.browser.fromPixels(videoRef.current)
    let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [224, 224], true)
    let normalizedTensorFrame = resizedTensorFrame.div(255)
    return mobileNet.predict(normalizedTensorFrame.expandDims()).squeeze()
  })

  setTrainingData([...trainingData, { input: imageFeatures, output: classNumber }])

  const canvas = document.createElement('canvas')
  canvas.width = videoRef.current.videoWidth
  canvas.height = videoRef.current.videoHeight
  canvas.getContext('2d').drawImage(videoRef.current, 0, 0)
  const capturedImage = canvas.toDataURL('image/png')
  setCapturedImages((prevImages) => {
    const classKey = `class_${classNumber + 1}`
    return {
      ...prevImages,
      [classKey]: [capturedImage, ...(prevImages[classKey] || [])],
    }
  })
}

Using tf.browser.fromPixels() I grab a frame from the webcam feed. Next, I resize the videoFrameAsTensor variable to the shape the MobileNet model expects as its input; remember, in loadMobileNetFeatureModel we warmed the model up with 224 by 224. I also normalise the pixel values to the 0 to 1 range by dividing by 255, pass the frame through MobileNet to get its image features and store those features, together with the class number, as training data for later use.

It might also be nice to show the user which images are being used in which class. I use a canvas to create a data URL for the image so I can show it in the UI.
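To give an idea of what that could look like in the UI, here is a minimal sketch (the class names and markup are illustrative, not the exact demo code) that renders a capture button and the stored thumbnails for each class:

const classNames = ['Neutral', 'Phone', 'Coffee']

// ... inside the component's JSX:

{classNames.map((className, classNumber) => (
  <div key={className}>
    <button onClick={() => gatherDataForClass(classNumber)}>Capture {className}</button>
    {/* Render the data URLs we stored for this class as thumbnails */}
    {(capturedImages[`class_${classNumber + 1}`] || []).map((src, index) => (
      <img key={index} src={src} alt={`Captured image for ${className}`} width={80} />
    ))}
  </div>
))}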

Training the model on-the-fly

Once the user has captured data (images) for the classes, it’s time to train the model with this new data.

const [progressBarValues, setProgressBarValues] = useState()
const shouldPredict = useRef(false)

function predictLoop() {
  // ...
}

const trainAndPredict = async () => {
  setIsTraining(true)
  shouldPredict.current = false

  const trainingDataInputs = trainingData.map((data) => data.input)
  const trainingDataOutputs = trainingData.map((data) => data.output)

  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs)
  let outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32')
  let oneHotOutputs = tf.oneHot(outputsAsTensor, 3)
  let inputsAsTensor = tf.stack(trainingDataInputs)

  await model.fit(inputsAsTensor, oneHotOutputs, {
    shuffle: true,
    batchSize: 5,
    epochs: 10,
  })

  outputsAsTensor.dispose()
  oneHotOutputs.dispose()
  inputsAsTensor.dispose()

  setIsTraining(false)
  shouldPredict.current = true
  predictLoop()
}

The trainAndPredict function takes the training data we stored before and passes it to the model. First, we shuffle the training data so that its order does not cause any issues during training. Next, I convert the outputs to a tensor and pass it to the tf.oneHot() function along with the maximum number of classes, which in the case of my demo is 3. Then I stack the input feature tensors into a single 2D tensor using the tf.stack() function. I can now finally train the model head using the model.fit() function, passing in the tensors. Finally, I dispose of the created tensors, as the model is trained and I don’t need them anymore.

Time to call the predictLoop function that will continuously grab a frame of the webcam, and predict in which class it falls:

function predictLoop() {
  // Break out of the loop when we should no longer predict.
  if (!shouldPredict.current) {
    return
  }

  tf?.tidy(function () {
    // Grab the current webcam frame, normalise it and resize it for MobileNet.
    let videoFrameAsTensor = tf.browser.fromPixels(videoRef.current).div(255)
    let resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [224, 224], true)

    // Get the MobileNet features and let the trained model head predict the class.
    let imageFeatures = mobileNet.predict(resizedTensorFrame.expandDims())
    let prediction = model.predict(imageFeatures).squeeze()
    let predictionArray = prediction.arraySync()

    setProgressBarValues(predictionArray)
  })

  window.requestAnimationFrame(predictLoop)
}

Since this is a recursive function scheduled with window.requestAnimationFrame, we first check whether we need to break out of the loop. Then we grab the current frame of the webcam, similar to how we previously grabbed images to gather data for the classes. Next, we call the predict method on the model to get a value between 0 and 1 for each class, indicating how confident it is that the class matches. The result might look something like [0.39075496792793274, 0.6091347932815552, 0.00011020545935025439], indicating that this is most likely the second class with hints of the first. Finally, we store these values so we can display a progress bar under each class to visualise how confident the model is that the current webcam frame belongs to that class.
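Rendering the confidence per class is then just a matter of mapping over these values; a minimal sketch (again with illustrative class names) could look like this:

{classNames.map((className, classNumber) => (
  <label key={className}>
    {className}
    {/* Each prediction value is between 0 and 1, so it maps directly onto a progress bar */}
    <progress value={progressBarValues?.[classNumber] || 0} max={1} />
  </label>
))}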

The final result

The above logic, combined with some extra application logic, results in the demo you can try out here (probably best on desktop). The final code can be found over at my GitHub.

Some things I noticed

While building this demo, I noticed a few things. Firstly, this stuff is hard! I don’t have a machine learning background, so I wasn’t familiar with many of the techniques and concepts. While doing the Codelab I read things like "Now it's time to define your model head, which is essentially a very minimal multi-layer perceptron.", which didn’t fill me with much confidence that I would be able to build this demo. Luckily, with the help of the Codelab and the TensorFlow.js documentation of the utilities that make this easier, I was able to make something cool!

I also noticed how accessible this actually is for developers. Being able to run this in the browser on a relatively "low-powered" machine (in machine learning terms) like my laptop is just amazing. I can share a link to this static website and let people play around with it and train models themselves!

Finally, I noticed that this is a gateway into machine learning for me. It hits the right balance between familiar concepts and being a bit out of my depth, which makes me want to learn more about the techniques in this exciting field. Try one of these Codelabs out. It might be a lot less daunting than you suspect.

Cool demo, what are actual use cases?

As the demo and the possibilities are so broad, it’s hard to see actual use cases at first. Just focussing on image recognition and transfer learning (which is a tiny part of what machine learning can offer), I can think of quite a few:

  • Vehicle Damage Assessment in Insurance Apps - Users can upload photos of damaged vehicles, and the app can classify the type and severity of damage, facilitating the claims process for insurance companies.
  • Vehicle Condition Assessment for Lease Returns - Leaseholders can upload photos of leased vehicles for inspection before returning them. The app can classify the condition of the vehicle, helping to determine if any additional charges are warranted.
  • Property Identification for Mortgage Assessment - Analyze uploaded photos of properties to classify them into residential, commercial, or mixed-use categories, assisting in property valuation and mortgage assessment.

You can kind of see where I’m going with this. Even scoped to this tiny part of what is possible with machine learning, it’s an incredibly useful tool we can offer to our users. In the end, it’s a tool that supplies us with data. We can then use that data to offer awesome capabilities and user experiences!

