On-the-Fly Machine Learning in the Browser with TensorFlow.js
By Dave Bitter
10 min read
TensorFlow.js is an incredibly powerful JavaScript library for training and deploying machine learning models in the browser and Node.js. Let's explore this library by building a teachable machine!
I first came across the teachable machine while researching TensorFlow.js. Quite frankly, the idea of this working seemed insane to me. Are you telling me I can run machine learning in the browser and train the model on the fly? You sure can! Luckily, Google provides a Codelab in which you build a simple version yourself, which I loosely followed.
The demo I'm going to build
I'm going to build a demo showing how you can apply transfer learning in the browser, in real time, using TensorFlow.js:
As you can see, I capture images for three classes: a neutral one, one with my phone in my hand and one with a coffee in my hand. When I then click "Train & Predict", the model trains in a fraction of a second. Finally, you can see in real time, for each class, how confident the model is that the current webcam frame matches it. Naturally, you can provide more diverse images so it gets better and better at classifying.
What is TensorFlow.js?
TensorFlow.js is a JavaScript library developed by Google that allows you to define, train, and run machine learning models entirely in the browser, using JavaScript and a high-level layers API. It is part of the TensorFlow ecosystem, which includes a range of tools for machine learning applications. The advantage of TensorFlow.js is that it allows machine learning models to be run in the browser (or in Node.js), making machine learning more accessible to JavaScript developers and allowing for real-time interaction with the user.
Some of the benefits are:
- Privacy - you can both train and classify data on the user's machine instead of sending it to a server
- Speed - you can classify data directly, without the round trip to a remote server
- Sensor access - you can access the sensors of the user's device, like their camera and microphone
- Deploy - you only have to deploy the web application and don't have to deal with complex server-side setups for machine learning
- Cost - last but not least, you can cut costs dramatically by hosting a (static) website and using the user's machine for the machine learning part
How does a teachable machine work?
A teachable machine basically takes an existing (or base) model and uses it on a similar but different domain. This is known as transfer learning. As humans, we do this all the time. We have a lifetime of experiences that we can use to recognize things we have never seen before. This example in the Codelab explains this best in my opinion. This is a willow tree:
There is a chance you might have never seen this type of tree before. Now that I have shown you, find the willow tree in this image:
You already have neurons in your brain that know how to identify objects that look like trees and long straight lines. You can use that knowledge to quickly classify the willow tree in this image as it is a tree, has long straight lines and you’ve learned that a willow tree looks like that.
Using the MobileNet base model
We need a model that has already been trained to classify objects, so we can build on it to teach it new things. Luckily, such a base model exists. MobileNet is a popular model that performs image recognition on 1000 different types of objects. It was trained on a huge dataset called ImageNet, which has millions of labelled images. During training, this model learned to spot common features among those 1000 objects. Many of these features, like lines, textures and shapes, can help identify new objects it hasn't seen before.
Let’s start building
For my demo, I bootstrapped a Next.js project, so the code examples use React. Naturally, this is just my preference for reactive web applications, and you can use any framework (or none at all).
Loading TensorFlow.js and MobileNet
First, I added TensorFlow.js through NPM with the @tensorflow/tfjs package. I then created a small hook where I load MobileNet:
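import { useEffect, useRef, useState } from 'react'
// Loads the MobileNet feature extractor and builds a small trainable "model head".
export const useMobileNet = ({ tf, numberOfClasses }) => {
  const model = useRef(null) // the trainable model head
  const mobileNet = useRef(null) // the MobileNet feature extractor
  const [readyToTrainAndPredict, setReadyToTrainAndPredict] = useState(false)
  useEffect(() => {
    // Wait until TensorFlow.js is available.
    if (!tf) {
      return
    }
    async function loadMobileNetFeatureModel() {
      // ...
    }
    loadMobileNetFeatureModel().then(() => {
      // ...
    })
  }, [tf])
  // Both refs are set before readyToTrainAndPredict flips to true,
  // so the re-render caused by that state update exposes them.
  return {
    model: model.current,
    mobileNet: mobileNet.current,
    readyToTrainAndPredict,
  }
}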
The hook receives an object with two properties: tf (the TensorFlow.js library) and numberOfClasses, which is the number of different classes it should detect and categorise. Let's dive into the loadMobileNetFeatureModel function:
async function loadMobileNetFeatureModel() {
  // MobileNet V3 (small) feature-vector model hosted on TF Hub.
  const URL =
    'https://tfhub.dev/google/tfjs-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/default/1'
  mobileNet.current = await tf.loadGraphModel(URL, {
    fromTFHub: true,
  })
  // Warm up the model by passing zeros through it once.
  // tf.tidy() disposes the intermediate tensors this creates.
  tf.tidy(function () {
    mobileNet.current?.predict(tf.zeros([1, 224, 224, 3]))
  })
}
I load the model from TF Hub, which is why I need to tell TensorFlow.js through fromTFHub: true. To warm the model up, I pass a tensor of zeros through it once, so one-time setup work happens before the first real prediction. The tensor's shape, [1, 224, 224, 3], consists of a few specific values: 1 is the batch size, 224 is both the width and the height of the image, and 3 is the number of colour channels (red, green and blue). Once the MobileNet feature model is loaded, I define the model head:
loadMobileNetFeatureModel().then(() => {
  model.current = tf.sequential()
  // Hidden layer: takes the 1024-value MobileNet feature vector as input.
  model.current.add(
    tf.layers.dense({
      inputShape: [1024],
      units: 128,
      activation: 'relu',
    })
  )
  // Output layer: one probability per class.
  model.current.add(
    tf.layers.dense({
      units: numberOfClasses,
      activation: 'softmax',
    })
  )
  // Print an overview of the layers and their parameters to the console.
  model.current.summary()
  model.current.compile({
    optimizer: 'adam',
    loss: numberOfClasses === 2 ? 'binaryCrossentropy' : 'categoricalCrossentropy',
    metrics: ['accuracy'],
  })
  setReadyToTrainAndPredict(true)
})
Quite a few things are happening here, so let's go over them. First, I create a sequential model: a stack of layers that learns patterns in data to make predictions. The first dense layer receives the 1024 values of the MobileNet feature vector (inputShape: [1024]) and passes them through 128 "neurons". Each neuron uses the ReLU activation function, which outputs max(0, x) and lets the model learn non-linear patterns. Next, I add the output layer, which has one unit per class. Its softmax activation turns the raw outputs into probabilities between 0 and 1 that add up to 1, so the model's prediction is the class with the highest probability. After that, I call model.summary() to print an overview of the layers to the console. Finally, before the model starts learning, I compile it: the adam optimizer determines how the weights are updated to improve the guesses over time, the loss function (categorical cross-entropy, or binary cross-entropy when there are only two classes) measures how wrong those guesses are, and the accuracy metric reports how often they are right.
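The forward pass of this model head and its loss are simple enough to write out by hand. The sketch below is a hypothetical plain-JavaScript illustration, not TensorFlow.js code: it uses tiny hand-picked weights instead of the real layer sizes (1024 inputs, 128 hidden units, numberOfClasses outputs), it assumes a small epsilon clamp to avoid log(0), and it leaves out training itself (the optimizer's job is to adjust the weights so this loss goes down).

```javascript
// Hypothetical plain-JavaScript illustration of the model head (not TensorFlow.js code).

// ReLU: negative values become 0, positive values pass through unchanged.
function relu(x) {
  return Math.max(0, x)
}

// Softmax: turns raw scores (logits) into probabilities that sum to 1.
function softmax(logits) {
  const max = Math.max(...logits) // subtracting the max avoids overflow in Math.exp
  const exps = logits.map((z) => Math.exp(z - max))
  const sum = exps.reduce((a, b) => a + b, 0)
  return exps.map((e) => e / sum)
}

// Dense layer: output[j] = activation(bias[j] + sum over i of input[i] * weights[i][j]).
function dense(input, weights, bias, activation = (x) => x) {
  return bias.map((b, j) =>
    activation(input.reduce((acc, x, i) => acc + x * weights[i][j], b))
  )
}

// A hidden ReLU layer followed by a softmax output layer, like the model head.
function forwardHead(features, params) {
  const hidden = dense(features, params.w1, params.b1, relu)
  const logits = dense(hidden, params.w2, params.b2)
  return softmax(logits)
}

// Categorical cross-entropy for one example: -sum(label[k] * log(probability[k])).
// The epsilon clamp is an assumption of this sketch, to avoid log(0).
function categoricalCrossentropy(oneHotLabel, probabilities, epsilon = 1e-7) {
  return -oneHotLabel.reduce(
    (acc, y, k) => acc + y * Math.log(Math.max(probabilities[k], epsilon)),
    0
  )
}

// Usage with 2 input features, 2 hidden units and 3 classes (hypothetical weights).
const probabilities = forwardHead([1, -1], {
  w1: [[1, 0], [0, 1]],
  b1: [0, 0],
  w2: [[1, 0, 0], [0, 1, 0]],
  b2: [0, 0, 0],
})
// probabilities ≈ [0.576, 0.212, 0.212]
const loss = categoricalCrossentropy([1, 0, 0], probabilities)
// loss ≈ 0.551 when the correct class is the first one
```

The lower the probability the model assigns to the correct class, the higher the loss, which is what the optimizer pushes down during model.fit().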
Using the user’s webcam to gather training data
Next, I want to be able to gather images to use in the demo. I could just let the user upload images, but let’s make it easier and allow them to snap images with their webcam to use. Yet again, I create a small hook to give us access:
import { useEffect, useRef, useState } from 'react'
export const useWebcam = () => {
  const videoRef = useRef(null)
  const [videoPlaying, setVideoPlaying] = useState(false)
  const toggleWebcam = async () => {
    // If the webcam is on, stop every video track to turn it off.
    if (videoPlaying) {
      videoRef.current.srcObject?.getVideoTracks().forEach((track) => {
        track.stop()
      })
      videoRef.current.srcObject = null
      setVideoPlaying(false)
      return
    }
    if (navigator.mediaDevices?.getUserMedia) {
      // getUserMedia parameters: width and height belong inside the video constraints.
      const constraints = {
        video: { width: 640, height: 480 },
      }
      try {
        // Activate the webcam stream (this prompts the user for permission).
        const stream = await navigator.mediaDevices.getUserMedia(constraints)
        videoRef.current.srcObject = stream
        videoRef.current.addEventListener(
          'loadeddata',
          () => {
            setVideoPlaying(true)
          },
          { once: true }
        )
      } catch (error) {
        console.warn('Could not access the webcam', error)
      }
    } else {
      console.warn('getUserMedia() is not supported by your browser')
    }
  }
  useEffect(() => {
    // Turn the webcam on automatically if camera access was granted before.
    // Not every browser supports querying the 'camera' permission, hence the catch.
    navigator.permissions
      ?.query({ name: 'camera' })
      .then((permissionStatus) => {
        if (permissionStatus.state === 'granted') {
          toggleWebcam()
        }
      })
      .catch(() => {})
  }, [])
  return {
    toggleWebcam,
    videoRef,
    videoPlaying,
  }
}
With this hook, I can toggle the webcam on and off and access the video reference, so I can attach it to an HTML video element:
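const { toggleWebcam, videoPlaying, videoRef } = useWebcam()
// ...
// autoPlay starts playback as soon as the stream is attached; playsInline and muted help on mobile browsers.
<video ref={videoRef} autoPlay playsInline muted />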
Next, for each class, we add a button to the UI to capture data. Once the user clicks this button, I call the gatherDataForClass function:
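const [capturedImages, setCapturedImages] = useState({})
const [isTraining, setIsTraining] = useState(false)
const [trainingData, setTrainingData] = useState([])
const gatherDataForClass = (classNumber) => {
  // tf.tidy() disposes every intermediate tensor except the one that is returned.
  const imageFeatures = tf?.tidy(function () {
    // Grab the current webcam frame as a [height, width, 3] tensor.
    const videoFrameAsTensor = tf.browser.fromPixels(videoRef.current)
    // Resize to the 224 x 224 input MobileNet expects (alignCorners = true).
    const resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [224, 224], true)
    // Scale pixel values from 0-255 to 0-1.
    const normalizedTensorFrame = resizedTensorFrame.div(255)
    // Add a batch dimension, extract the features, then remove the batch dimension.
    return mobileNet.predict(normalizedTensorFrame.expandDims()).squeeze()
  })
  // Functional update, so quick successive clicks don't overwrite each other's samples.
  setTrainingData((prevData) => [...prevData, { input: imageFeatures, output: classNumber }])
  // Draw the same frame onto a canvas to get a thumbnail for the UI.
  const canvas = document.createElement('canvas')
  canvas.width = videoRef.current.videoWidth
  canvas.height = videoRef.current.videoHeight
  canvas.getContext('2d').drawImage(videoRef.current, 0, 0)
  const capturedImage = canvas.toDataURL('image/png')
  setCapturedImages((prevImages) => {
    const classKey = `class_${classNumber + 1}`
    return {
      ...prevImages,
      // Fall back to an empty array the first time a class receives an image.
      [classKey]: [capturedImage, ...(prevImages[classKey] ?? [])],
    }
  })
}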
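Using tf.browser.fromPixels(), I grab a frame from the webcam feed as a tensor. Next, I resize videoFrameAsTensor to the input shape the MobileNet model expects: 224 by 224 pixels, the same size we used to warm up the model in loadMobileNetFeatureModel. I then divide the pixel values by 255 so they fall between 0 and 1, and pass the result through MobileNet to get a feature vector of 1024 values. I store this feature vector, together with its class number, in an array for later use.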
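For intuition, here is a minimal plain-JavaScript sketch of bilinear resizing with alignCorners set to true (the third argument passed to tf.image.resizeBilinear() above), followed by scaling to the 0-1 range. It is simplified to a single-channel image stored as nested arrays, and the function names are hypothetical; it illustrates the technique rather than TensorFlow.js's own implementation.

```javascript
// Hypothetical plain-JavaScript illustration (not TensorFlow.js code) of bilinear
// resizing with alignCorners = true, for a single-channel image as a 2D array.
function resizeBilinearAlignCorners(image, outHeight, outWidth) {
  const inHeight = image.length
  const inWidth = image[0].length
  // With alignCorners = true, the corner pixels of input and output line up exactly.
  const scaleY = outHeight > 1 ? (inHeight - 1) / (outHeight - 1) : 0
  const scaleX = outWidth > 1 ? (inWidth - 1) / (outWidth - 1) : 0
  const output = []
  for (let y = 0; y < outHeight; y++) {
    // Position of this output row in input coordinates, and its two neighbouring rows.
    const srcY = y * scaleY
    const y0 = Math.floor(srcY)
    const y1 = Math.min(y0 + 1, inHeight - 1)
    const dy = srcY - y0
    const row = []
    for (let x = 0; x < outWidth; x++) {
      const srcX = x * scaleX
      const x0 = Math.floor(srcX)
      const x1 = Math.min(x0 + 1, inWidth - 1)
      const dx = srcX - x0
      // Interpolate horizontally on both neighbouring rows, then vertically between them.
      const top = image[y0][x0] * (1 - dx) + image[y0][x1] * dx
      const bottom = image[y1][x0] * (1 - dx) + image[y1][x1] * dx
      row.push(top * (1 - dy) + bottom * dy)
    }
    output.push(row)
  }
  return output
}

// Scale 0-255 pixel values to the 0-1 range, like .div(255).
function normalizePixels(image) {
  return image.map((row) => row.map((value) => value / 255))
}
```

For example, resizeBilinearAlignCorners([[0, 10], [20, 30]], 3, 3) keeps the four corner values and fills in the gaps, returning [[0, 5, 10], [10, 15, 20], [20, 25, 30]].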
It's also nice to show the user which images belong to which class. To do that, I draw the current video frame onto a canvas and call canvas.toDataURL() to get an image URL that I can show in the UI.
Training the model on-the-fly
Once the user has captured data (images) for the classes, it’s time to train the model with this new data.
const [progressBarValues, setProgressBarValues] = useState()
const shouldPredict = useRef(false)
function predictLoop() {
  // ...
}
const trainAndPredict = async () => {
  setIsTraining(true)
  // Pause the prediction loop while the model head is being retrained.
  shouldPredict.current = false
  const trainingDataInputs = trainingData.map((data) => data.input)
  const trainingDataOutputs = trainingData.map((data) => data.output)
  // Shuffle inputs and outputs in unison so every input keeps its label.
  tf.util.shuffleCombo(trainingDataInputs, trainingDataOutputs)
  const outputsAsTensor = tf.tensor1d(trainingDataOutputs, 'int32')
  // 3 is the number of classes in this demo; it must match numberOfClasses.
  const oneHotOutputs = tf.oneHot(outputsAsTensor, 3)
  // Stack the 1D feature tensors into one 2D tensor of shape [samples, 1024].
  const inputsAsTensor = tf.stack(trainingDataInputs)
  try {
    await model.fit(inputsAsTensor, oneHotOutputs, {
      shuffle: true,
      batchSize: 5,
      epochs: 10,
    })
  } finally {
    // Free the memory held by these temporary tensors, even if training fails.
    outputsAsTensor.dispose()
    oneHotOutputs.dispose()
    inputsAsTensor.dispose()
    setIsTraining(false)
  }
  shouldPredict.current = true
  predictLoop()
}
Calling the trainAndPredict function takes the training data we stored before and passes it to the model. First, we shuffle the training data so that the order in which the images were captured doesn't cause any issues during training. Next, I convert the class numbers (the outputs) to an int32 tensor and pass it to the tf.oneHot() function, along with the number of classes, which is 3 in my demo. This turns a class number like 1 into [0, 1, 0]. Then I use the tf.stack() function to combine the array of 1D feature tensors (the inputs) into a single 2D tensor. I can now finally train the model head with the model.fit() function, passing it these tensors. Finally, I dispose of the created tensors, as the model is trained and I don't need them anymore.
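To make these two data-preparation steps more concrete, here is a minimal plain-JavaScript sketch of what tf.oneHot() and tf.util.shuffleCombo() do conceptually. The function names oneHot and shuffleTogether are hypothetical, they work on regular arrays instead of tensors, and the Fisher-Yates shuffle is simply one standard way to shuffle two arrays in unison, not necessarily how TensorFlow.js does it internally.

```javascript
// Hypothetical plain-JavaScript equivalents, for illustration only.

// One-hot encoding: class number k becomes an array of zeros with a 1 at index k.
function oneHot(labels, depth) {
  return labels.map((label) => {
    const row = new Array(depth).fill(0)
    row[label] = 1
    return row
  })
}

// Fisher-Yates shuffle that applies the same permutation to both arrays,
// so every input stays paired with its label. Shuffles in place.
function shuffleTogether(inputs, outputs, random = Math.random) {
  if (inputs.length !== outputs.length) {
    throw new Error('inputs and outputs must have the same length')
  }
  for (let i = inputs.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1))
    const tmpInput = inputs[i]
    inputs[i] = inputs[j]
    inputs[j] = tmpInput
    const tmpOutput = outputs[i]
    outputs[i] = outputs[j]
    outputs[j] = tmpOutput
  }
}
```

For example, oneHot([0, 2, 1], 3) returns [[1, 0, 0], [0, 0, 1], [0, 1, 0]], and after shuffleTogether(inputs, outputs) the pairs are in a new order but each input still sits next to its own class number.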
Time to call the predictLoop function, which continuously grabs a frame from the webcam and predicts which class it falls into:
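function predictLoop() {
  // Stop the loop here: returning inside tf.tidy() alone would not stop requestAnimationFrame.
  if (!shouldPredict.current) {
    return
  }
  tf?.tidy(function () {
    // Grab the current frame and scale pixel values to 0-1.
    const videoFrameAsTensor = tf.browser.fromPixels(videoRef.current).div(255)
    const resizedTensorFrame = tf.image.resizeBilinear(videoFrameAsTensor, [224, 224], true)
    // MobileNet turns the frame into a feature vector...
    const imageFeatures = mobileNet.predict(resizedTensorFrame.expandDims())
    // ...and the trained model head turns that into one probability per class.
    const prediction = model.predict(imageFeatures).squeeze()
    const predictionArray = prediction.arraySync()
    setProgressBarValues(predictionArray)
  })
  // Schedule the next prediction for the next animation frame.
  window.requestAnimationFrame(predictLoop)
}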
As this function keeps scheduling itself with window.requestAnimationFrame, we first check whether we need to break out of the loop. Next, we grab the current frame of the webcam, similarly to how we grabbed images to gather data for the classes. We pass it through MobileNet to get its features, and then call the predict method on the model head to get a value between 0 and 1 for each class, indicating how confident it is that the class matches. Because of the softmax output layer, these values add up to 1. The results might look something like [0.39075496792793274, 0.6091347932815552, 0.00011020545935025439], indicating that this is most likely the second class, with hints of the first class. Finally, we store these values so we can display a progress bar under each class that visualises how certain the model is that the current webcam frame belongs to that class.
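The prediction array can be turned into something readable with a few lines of JavaScript. The helper below, describePrediction, is hypothetical (it is not part of TensorFlow.js or the demo's published code), and the class labels are placeholders for the neutral, phone and coffee classes from the demo. It picks the class with the highest probability and rounds each value to a whole percentage, for example for the width of each progress bar.

```javascript
// Hypothetical helper: interpret the softmax output of the model head.
function describePrediction(predictionArray, classNames) {
  // Find the index of the highest probability (argmax).
  let bestIndex = 0
  predictionArray.forEach((p, i) => {
    if (p > predictionArray[bestIndex]) {
      bestIndex = i
    }
  })
  return {
    bestClass: classNames[bestIndex],
    confidence: predictionArray[bestIndex],
    // Whole-number percentages, e.g. for the progress bar under each class.
    percentages: predictionArray.map((p) => Math.round(p * 100)),
  }
}

const result = describePrediction(
  [0.39075496792793274, 0.6091347932815552, 0.00011020545935025439],
  ['neutral', 'phone', 'coffee']
)
// result.bestClass === 'phone', result.percentages is [39, 61, 0]
```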
The final result
The above logic, combined with some extra application logic, results in the demo you can try out here (probably best on desktop). The final code can be found over at my GitHub.
Some things I noticed
While building this demo, I noticed a few things. Firstly, this stuff is hard! I don't have a machine learning background, so I wasn't familiar with many of the techniques and concepts. While doing the Codelab, I read things like "Now it's time to define your model head, which is essentially a very minimal multi-layer perceptron.", which didn't exactly fill me with confidence that I could build this demo. Luckily, with the help of the Codelab and the TensorFlow.js documentation for the utilities that make this stuff easier, I was able to make something cool!
I also noticed how accessible this actually is for developers. Being able to run this in the browser on a relatively "low-powered" machine (in machine learning terms) like my laptop is just amazing. I can share a link to this static website and let people play around with it and train their own models!
Finally, I noticed that this is a gateway into machine learning for me. It strikes the right balance between familiar concepts and being just far enough out of my depth to want to learn more about the techniques in this exciting field. Try one of these Codelabs out. It might be a lot less daunting than you suspect.
Cool demo, what are actual use cases?
Because the possibilities are so broad, it's hard to see actual use cases at first. Focusing only on image recognition and transfer learning (which is a tiny part of what machine learning can offer), I can think of quite a few:
- Vehicle Damage Assessment in Insurance Apps - Users can upload photos of damaged vehicles, and the app can classify the type and severity of damage, facilitating the claims process for insurance companies.
- Vehicle Condition Assessment for Lease Returns - Leaseholders can upload photos of leased vehicles for inspection before returning them. The app can classify the condition of the vehicle, helping to determine if any additional charges are warranted.
- Property Identification for Mortgage Assessment - Analyze uploaded photos of properties to classify them into residential, commercial, or mixed-use categories, assisting in property valuation and mortgage assessment.
You can see where I'm going with this. Even scoped to this tiny part of what is possible with machine learning, it's an incredibly useful tool we can offer to our users. In the end, it's a tool that supplies us with data, which we can then use to offer awesome capabilities and user experiences!