source link: https://jott.live/markdown/images_as_emoji

Over-Engineering an Emoji Webcam Filter with a Neural Network

by @bwasti (mastodon)


ASCII-rendered images have been around for a very long time. Over the holidays I decided to take the logical next step and render with a far richer set of codepoints in the Unicode space: emoji!

Here's a screen recording of me testing fidelity with some album art (stills at the bottom):

[Screen recording: the emoji filter rendering album art]

All the code: https://github.com/bwasti/emojicam

The Approach

Unlike ASCII, emoji have a lot of structure and color. Capturing all of that in a programmatic way is arduous, so I didn't attempt any sort of lookup table.

Instead, I trained a tiny neural network that could be invoked quickly enough for a "real-time" feel. I decided a 54x72 grid of emoji would be enough to render with reasonable fidelity, and that a frame rate of 5 fps would be acceptable.

So, 54×72×5=19440 invocations per second. On my M1 I get around 1Tflops of performance in raw matrix multiplications, so that's a budget of 50Mflops per inference at peak. Easily enough! Unfortunately, though, this requires the execution of the network to be server-side.
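The budget arithmetic above can be written out as a small TypeScript sketch. The function name is illustrative, and the 1 Tflop peak is the round figure quoted in the text, not a measurement:

```typescript
// Rough per-inference FLOP budget for a real-time emoji grid.
// Assumption: peakFlops is a round-number peak throughput, as quoted in the text.
function flopBudgetPerInference(
  rows: number,
  cols: number,
  fps: number,
  peakFlops: number,
): { invocationsPerSecond: number; flopsPerInference: number } {
  // one classifier invocation per grid cell per frame
  const invocationsPerSecond = rows * cols * fps
  return {
    invocationsPerSecond,
    flopsPerInference: peakFlops / invocationsPerSecond,
  }
}

const budget = flopBudgetPerInference(54, 72, 5, 1e12)
console.log(budget.invocationsPerSecond) // 19440
console.log((budget.flopsPerInference / 1e6).toFixed(1)) // 51.4 (Mflops)
```

Plugging in 1e11 instead gives roughly 5.1 Mflops per inference, the tighter budget discussed in the Modeling section.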

The Details

I contribute to a project called Shumai, a TypeScript-based tensor library, so I used that.

Biases aside, TypeScript is pretty ideal for this project because (1) it's much faster than Python for training with tiny data (thousands of 36x36x3 images) and (2) I was able to get webcam integration through a browser + server in a couple dozen lines of native code.

I unfortunately don't have a set of pixel patches classified by their respective "closest emoji," so I had to generate the dataset with some hacks. The basic idea was to take a reference image:

and mutate it by changing the color, rotating slightly, recropping, and either blurring or sharpening it:

The code to do that looks something like this:

import { Image } from '@shumai/image'
import * as sm from '@shumai/shumai'

function mutateColor(img) {
  let t = img.tensor().astype(sm.dtype.Float32)
  const r = 1 + (Math.random() - 0.5) / 5
  const g = 1 + (Math.random() - 0.5) / 5
  const b = 1 + (Math.random() - 0.5) / 5
  const s = 1 + (Math.random() - 0.5) / 5
  t = t
    .div(sm.scalar(255))
    .mul(sm.tensor(new Float32Array([r, g, b, 1])))
    .mul(sm.scalar(s))
    .mul(sm.scalar(255))
  t = sm.clamp(t, 0, 255)
  return new Image(t)
}

function randomCrop(img) {
  // max 20% shift
  const scale = 1 + Math.random() / 5
  const offset_x = Math.floor(Math.random() * img.width * (scale - 1))
  const offset_y = Math.floor(Math.random() * img.height * (scale - 1))
  return img.resize(scale).crop(offset_x, offset_y, img.width, img.height)
}

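// base_imgs (the reference emoji images) and h (the target height) are defined elsewhere
for (const base_img of base_imgs) {
  // rotate slightly: a random angle in [-10, 10)
  let img = base_img.rotate((Math.random() - 0.5) * 20)
  const scale = h / img.height
  img = img.resize(scale)
  img = mutateColor(img)
  // blur roughly half of the samples
  if (Math.random() > 0.5) {
    img = img.gaussblur(Math.random() / 2)
  }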

  img = randomCrop(img)
  img = img.flatten(30, 30, 30)
  // ...
}

Modeling

The next step was to write a model.

As I mentioned above, I had a pretty large flop budget, but that only applies if the model runs near peak. Unfortunately, the AMX unit on an M1 (which gives the CPU over a Tflop of performance) isn't used by Shumai's convolutions (yet!). As a result, true peak for some of the operators is closer to 100 Gflops. That means the budget is only about 5 Mflops per inference, and I ended up being extremely conservative with it:

import * as sm from '@shumai/shumai'

export class EmojiClassifier extends sm.module.Module {
  constructor(num_emojis) {
    super()
    this.conv0 = sm.module.conv2d(3, 8, 3, { stride: 2 })
    this.conv1 = sm.module.conv2d(8, 32, 3, { stride: 2 })
    this.linear = sm.module.linear(32, num_emojis)
    this.skip_softmax = false
  }
  forward(x) {
    x = sm.avgPool2d(x, 2, 2, 2, 2)
    x = this.conv0(x)
    x = x.tanh()
    x = sm.avgPool2d(x, 2, 2, 2, 2)
    x = this.conv1(x)
    x = x.tanh()
    x = x.reshape([x.shape[0], 32])
    x = this.linear(x)
    if (this.skip_softmax) {
      return x
    }
    return x.softmax(1)
  }
}

(that's actually the full file)
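To see just how conservative, here is a sketch that traces spatial sizes and multiply-accumulate counts through the layers above. It assumes 36x36 RGB patches, convolutions with no padding (which is what makes the final reshape to [N, 32] line up), 2x2/stride-2 average pooling, and roughly 1,700 output classes; bias, tanh, and pooling costs are ignored.

```typescript
// Output size of a 2-D window op (convolution or pooling) with no padding.
function outSize(inSize: number, kernel: number, stride: number): number {
  return Math.floor((inSize - kernel) / stride) + 1
}

// Trace sizes and multiply-accumulates (MACs) through EmojiClassifier.
// Assumptions: 36x36 RGB input, zero padding, 2x2/stride-2 avg pooling.
function traceEmojiClassifier(numEmojis: number) {
  let size = 36
  size = outSize(size, 2, 2) // avgPool2d -> 18
  size = outSize(size, 3, 2) // conv0 (3 -> 8 channels) -> 8
  const conv0Macs = size * size * 8 * (3 * 3 * 3)
  size = outSize(size, 2, 2) // avgPool2d -> 4
  size = outSize(size, 3, 2) // conv1 (8 -> 32 channels) -> 1
  const conv1Macs = size * size * 32 * (8 * 3 * 3)
  const linearMacs = 32 * numEmojis
  const macs = conv0Macs + conv1Macs + linearMacs
  // count each MAC as 2 flops (one multiply, one add)
  return { finalSize: size, macs, flops: 2 * macs }
}

const trace = traceEmojiClassifier(1700)
console.log(trace.finalSize) // 1, so reshape([N, 32]) lines up
console.log(trace.flops) // 141056
```

Under these assumptions the model needs roughly 0.14 Mflops per patch, well inside the ~5 Mflop budget, and most of that goes to the final linear layer.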

The softmax at the end is only necessary for training, so I skip it for inference runs (that's what skip_softmax is for). Dropping it is safe because softmax preserves the argmax, which is all we care about at the end of the day (the index of the predicted emoji).
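Why softmax preserves the argmax: exp is strictly increasing, and every entry is divided by the same positive sum, so the ordering of the entries never changes. A quick check on plain number arrays (not Shumai tensors):

```typescript
// Index of the largest entry (first one wins on ties).
function argmax(xs: number[]): number {
  let best = 0
  for (let i = 1; i < xs.length; i++) {
    if (xs[i] > xs[best]) best = i
  }
  return best
}

// Numerically stable softmax: subtracting the max leaves the result unchanged.
function softmax(xs: number[]): number[] {
  const m = Math.max(...xs)
  const exps = xs.map((x) => Math.exp(x - m))
  const sum = exps.reduce((a, b) => a + b, 0)
  return exps.map((e) => e / sum)
}

const logits = [1.5, -0.3, 4.2, 4.1]
console.log(argmax(logits)) // 2
console.log(argmax(softmax(logits))) // 2
```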

Training

To train, I used mean squared error as the loss function. This turned out to converge much more quickly than cross entropy.

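// model (an EmojiClassifier) and dump (a progress callback) are defined elsewhere
const loss_fn = sm.loss.mse
const optim = new sm.optim.Adam(1e-3)
let ema_loss = 0 // exponential moving average of the training loss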

for (const iter of sm.util.viter(20000, dump)) {
  const [X, Y] = getBatch([2, 4, 4])
  const Y_hat = model(X)
  const loss = loss_fn(Y_hat, Y)
  if (ema_loss === 0) {
    ema_loss = loss.toFloat32()
  }
  ema_loss = 0.9 * ema_loss + 0.1 * loss.toFloat32()
  const grads = loss.backward()
  optim.step(grads)
}

The weird getBatch call above does nothing more than fetch a batch of tensors from the dataset. It's weird because ArrayFire (and thus Shumai) inexplicably doesn't support concatenating more than 4 tensors at once. To get around this, I implemented the batch function to concatenate recursively. Luckily, Bun's JIT makes this super fast, so it's not a huge deal.
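A minimal sketch of the recursive batching idea, using plain arrays in place of tensors. This getBatch is a hypothetical reconstruction, not the actual implementation: it assumes each entry of the argument is the fan-in at one level of the concat tree, so [2, 4, 4] builds 2 × 4 × 4 = 32 samples without any single concat seeing more than 4 inputs. concatLimited is a stand-in for the 4-input tensor concat.

```typescript
// Stand-in for a tensor concat that accepts at most 4 inputs.
const MAX_CONCAT_INPUTS = 4

function concatLimited<T>(parts: T[][]): T[] {
  if (parts.length > MAX_CONCAT_INPUTS) {
    throw new Error(`concat supports at most ${MAX_CONCAT_INPUTS} inputs, got ${parts.length}`)
  }
  return ([] as T[]).concat(...parts)
}

// Hypothetical getBatch: build fanout[0] sub-batches recursively from
// fanout.slice(1), then concatenate them. No single concat sees more than
// 4 inputs as long as every fan-out entry is <= 4.
function getBatch<T>(fanout: number[], sample: () => T): T[] {
  if (fanout.length === 0) {
    return [sample()]
  }
  const [n, ...rest] = fanout
  const parts: T[][] = []
  for (let i = 0; i < n; i++) {
    parts.push(getBatch(rest, sample))
  }
  return concatLimited(parts)
}

let next = 0
const batch = getBatch([2, 4, 4], () => next++)
console.log(batch.length) // 32
```

In the real code, sample would return an (X, Y) pair of tensors rather than a number.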

sm.util.viter is a lot like tqdm.tqdm, but it takes a callback to print extra information. The dump callback used here prints the loss, the accuracy, and an example classification. It's pretty cool to watch the network slowly learn structure before classifying correctly most of the time. At first, for example, it would mistake flags for other flags or faces for other faces, before converging to be quite accurate (over 95% on 1,700 classes).
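The loss compares the model's softmax output to the label. Assuming the labels from getBatch are one-hot vectors (the dataset code isn't shown), per-sample mean squared error works out as below. Cross entropy is included only for comparison, and sm.loss.mse may average over the whole batch rather than per sample.

```typescript
// Mean squared error between a predicted distribution and a one-hot target.
function mseOneHot(probs: number[], label: number): number {
  let sum = 0
  for (let i = 0; i < probs.length; i++) {
    const target = i === label ? 1 : 0
    sum += (probs[i] - target) ** 2
  }
  return sum / probs.length
}

// Cross entropy for the same prediction, for comparison.
function crossEntropy(probs: number[], label: number): number {
  return -Math.log(probs[label])
}

const probs = [0.5, 0.25, 0.25]
console.log(mseOneHot(probs, 0)) // 0.125
console.log(crossEntropy(probs, 0)) // ≈ 0.693
```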

Serving

Now for the fun part!

Once you hook up the webcam to a video element on a page, the actual client logic is extremely simple:

context.drawImage(video, 0, 0, canvas.width, canvas.height);
const data = context.getImageData(0, 0, canvas.width, canvas.height).data;
const response = await fetch(window.location.href, {
  method: 'POST',
  body: data
})
const emojis = await response.json()

I send the raw pixels to the server with fetch and then await a response. The response from the server is a list of 3,888 (54 × 72) emoji that can be rendered in a grid.
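One consequence of the server code below: it reshapes the POST body to [54 * 36, 72 * 36, 4], so the canvas has to be 2592x1944 pixels of RGBA data. A sketch of that size check plus splitting the flat, row-major reply into grid rows; expectedPayloadBytes and toGridRows are hypothetical helpers, not taken from the actual page:

```typescript
const GRID_ROWS = 54
const GRID_COLS = 72
const PATCH = 36 // pixels per emoji cell, per the server's reshape

// Bytes of RGBA data the server expects per frame.
function expectedPayloadBytes(): number {
  return GRID_ROWS * PATCH * GRID_COLS * PATCH * 4
}

// Split the flat, row-major list of emoji into one string per grid row.
function toGridRows(emojis: string[], cols = GRID_COLS): string[] {
  if (emojis.length % cols !== 0) {
    throw new Error(`expected a multiple of ${cols} emoji, got ${emojis.length}`)
  }
  const rows: string[] = []
  for (let i = 0; i < emojis.length; i += cols) {
    rows.push(emojis.slice(i, i + cols).join(''))
  }
  return rows
}

console.log(expectedPayloadBytes()) // 20155392 (about 20 MB per frame)
```

At 5 fps that is roughly 100 MB/s of uncompressed pixels, which is fine on localhost but worth keeping in mind over a real network.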

The server makes use of Bun's serve method. Below is the full code without imports:

const emoji = await file('emoji.json').json() // Bun's file() reads and parses the emoji list
const model = new EmojiClassifier(emoji.length)
model.skip_softmax = true
model.checkpoint('weights') // this loads all the weights

serve({
  async fetch(req) {
    if (req.method === 'POST') {
      const response = []
      const data = new Uint8Array(await req.arrayBuffer())

      const t0 = performance.now()

      sm.util.tidy(() => {
        // convert data to a float tensor
        let t = sm.tensor(data).astype(sm.dtype.Float32)
        const height = 54
        const width = 72

        // mask out the alpha channel
        t = t.reshape([height * 36, width * 36, 4]).index([':', ':', ':3'])

        // arrayfire only supports dim <= 4 :( so we have to hack
        t = t.reshape([height, 36, width, 36 * 3])
        t = t.transpose([0, 2, 1, 3])
        t = t.reshape([height * width, 36, 36, 3])
        t = t.transpose([0, 3, 1, 2]).div(sm.scalar(255)).eval()

        // run the model
        const out = model(t).argmax(1)

        // convert indices back to emoji (strings)
        for (const i of out.toInt32Array()) {
          response.push(emoji[i])
        }
      })

      const t1 = performance.now()
      sm.util.tuiLoad(`${1e3 / (t1 - t0)} iters/sec`)

      // all set to return!
      return new Response(JSON.stringify(response))
    }
    return new Response(file('./index.html'))
  }
})

If a POST request is sent, the server blindly converts the raw data to a tensor (and casts it to float32). The call to sm.util.tidy isn't strictly necessary, but it provides a hint to the garbage collector that there is no tensor re-use across runs. This keeps memory pressure stable.

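Then it strips out the alpha channel (the index call) and shuffles axes around to produce a batch of 36x36 pixel patches. The slightly odd sequence of reshapes and transposes is because ArrayFire doesn't support more than 4 dimensions (bleh), so the 36 pixel columns and 3 color channels of each patch row are temporarily folded into a single axis of size 108.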
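To make that index shuffling concrete, here is an equivalent loop-based version over the raw RGBA bytes. It's a sketch for checking the index math, not what the server runs: it assumes a row-major layout and produces the same [rows * cols, 3, 36, 36] (NCHW) float layout, scaled to [0, 1], that the tensor code ends with.

```typescript
// Cut an RGBA image made of rows x cols square patches (p x p pixels each)
// into a flat NCHW float buffer: patch n = r * cols + c, channels R, G, B.
function toPatches(
  rgba: Uint8Array | Uint8ClampedArray,
  rows: number,
  cols: number,
  p = 36,
): Float32Array {
  const width = cols * p // image width in pixels
  const out = new Float32Array(rows * cols * 3 * p * p)
  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < cols; c++) {
      const n = r * cols + c // patch index, row-major over the grid
      for (let ch = 0; ch < 3; ch++) { // channel 3 (alpha) is skipped
        for (let y = 0; y < p; y++) {
          for (let x = 0; x < p; x++) {
            const src = ((r * p + y) * width + (c * p + x)) * 4 + ch
            const dst = ((n * 3 + ch) * p + y) * p + x
            out[dst] = rgba[src] / 255
          }
        }
      }
    }
  }
  return out
}
```

Calling toPatches(data, 54, 72) on the POST body would give 3,888 patches of 3 × 36 × 36 values each.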

Finally, the model is run, argmax is taken on the trailing dim, and the result is used to create an emoji response list. Done!

The Results

[Still image]

You can see the learned sub-structures in the "12".

[Still image]

No skin-tone emoji were used, so the network picks various shades of pink for my arm.

[Still image]

An emoji rendered with more emoji.

[Still image]

I thought the rainbow looked cool here.

Thanks for reading!

The full code is open source and can be found here: https://github.com/bwasti/emojicam

If you're interested in following (I post mostly about machine learning, performance and a love of {Java,Type}Script), I use Twitter and Mastodon!

