Bursting fireworks with Metal

January 26, 2025

Today we are finally ready to recreate this animation once again.

If you missed it, last time we attempted to implement these fireworks using fragment shaders only. Although we got a good result, it lacked some detail and performance.

This time we'll do it by using more of MetalKit's features. We will generate a fireworks grid, adjust vertex parameters and apply post-processing to get the glow effect. I'm thrilled with the end result.

Even if you don't implement this animation in a real application, I guarantee that the experience gained from this article will help you become more fluent with MetalKit tools.

If you are a TL;DR person, all the links, the source code and an example of the end result are at the end of the article.

Initial Setup

Every piece of MetalKit work like this starts with a fair amount of boilerplate code. To draw the firework scene we define a render pipeline state. Buffers are required to pass data to the shaders, and textures will be used for post-processing the final result.

To make this code work on both macOS and iOS, we define a typealias to shim the required base class.

import MetalKit
import SwiftUI
 
#if os(macOS)
typealias ViewController = NSViewController
#else
typealias ViewController = UIViewController
#endif
 
final class FireworkViewController: ViewController {
  private var device: (any MTLDevice)!
  private var commandQueue: (any MTLCommandQueue)!
 
  private var scenePipelineState: (any MTLRenderPipelineState)!
 
  private var vertexBuffer: (any MTLBuffer)!
  private var progressBuffer: (any MTLBuffer)!
  private var uniformBuffer: (any MTLBuffer)!
 
  private var sceneTexture: (any MTLTexture)!
  private var glowTexture: (any MTLTexture)!
 
  private lazy var canvasView = MTKView()
}

Most of the environment configuration happens in loadView and viewDidLoad. We also respond to layout changes so that the textures are rebuilt to match the new view size.

extension FireworkViewController {
  override func loadView() {
    view = canvasView
  }
 
  override func viewDidLoad() {
    super.viewDidLoad()
 
    device = MTLCreateSystemDefaultDevice()
    commandQueue = device.makeCommandQueue()
    canvasView.device = device
    canvasView.delegate = self
 
    buildPipelineStates()
    buildBuffers()
  }
 
  #if os(macOS)
  override func viewDidLayout() {
    super.viewDidLayout()
 
    buildResources(size: canvasView.bounds.size)
  }
  #else
  override func viewDidLayoutSubviews() {
    super.viewDidLayoutSubviews()
 
    buildResources(size: canvasView.bounds.size)
  }
  #endif
}

To silence the delegate error for now we define empty implementations of the MTKViewDelegate methods.

extension FireworkViewController: MTKViewDelegate {
  func mtkView(
    _ view: MTKView,
    drawableSizeWillChange size: CGSize
  ) {}
  
  func draw(in view: MTKView) {}
}

Configure the pipeline descriptor with blending settings, so that the firework is properly alpha-blended with the background.

private extension FireworkViewController {
  func buildPipelineStates() {
    guard
      let library = device.makeDefaultLibrary()
    else {
      return
    }
 
    let scenePipelineDescriptor = MTLRenderPipelineDescriptor()
    scenePipelineDescriptor.vertexFunction = library.makeFunction(name: "Firework::vertexScene")
    scenePipelineDescriptor.fragmentFunction = library.makeFunction(name: "Firework::fragmentScene")
    scenePipelineDescriptor.colorAttachments[0].pixelFormat = .bgra8Unorm
    scenePipelineDescriptor.colorAttachments[0].isBlendingEnabled = true
    scenePipelineDescriptor.colorAttachments[0].rgbBlendOperation = .add
    scenePipelineDescriptor.colorAttachments[0].alphaBlendOperation = .add
    scenePipelineDescriptor.colorAttachments[0].sourceRGBBlendFactor = .sourceAlpha
    scenePipelineDescriptor.colorAttachments[0].sourceAlphaBlendFactor = .sourceAlpha
    scenePipelineDescriptor.colorAttachments[0].destinationRGBBlendFactor = .oneMinusSourceAlpha
    scenePipelineDescriptor.colorAttachments[0].destinationAlphaBlendFactor = .oneMinusSourceAlpha
 
    do {
      scenePipelineState = try device.makeRenderPipelineState(descriptor: scenePipelineDescriptor)
    } catch {
      fatalError("Failed to create scene pipeline state: \(error)")
    }
  }
}

Initialise the buffers with some amount of memory. For simplicity we put a fairly large value here, but ideally we should calculate it based on the actual number of vertices we will be using.

private extension FireworkViewController {
  ...
 
  func buildBuffers() {
    vertexBuffer = device.makeBuffer(
      length: 10 * 1000 * 1000
    )
 
    progressBuffer = device.makeBuffer(
      length: 10 * 1000 * 1000
    )
 
    uniformBuffer = device.makeBuffer(
      length: 10 * 1000 * 1000
    )
  }
}
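The 10 MB allocation above is a placeholder. As a rough sketch of sizing the buffers from the mesh instead, assuming the layout used later in this article (each trail segment emits two triangles, i.e. six vertices; each vertex takes four Floats in vertexBuffer and one Float in progressBuffer), the byte counts can be derived like this. MeshBufferSizes and meshBufferSizes(segments:palmTrails:) are hypothetical helpers, and palmTrails anticipates the burst palm added later.

```swift
/// Hypothetical helper: byte sizes needed for one firework mesh.
struct MeshBufferSizes {
  let vertexCount: Int
  let vertexBytes: Int    // 4 Floats (x, y, z, w) per vertex
  let progressBytes: Int  // 1 Float per vertex
}

func meshBufferSizes(segments: Int, palmTrails: Int) -> MeshBufferSizes {
  // The launch trail plus every palm trail each emit `segments` quads,
  // and every quad is two triangles, i.e. 6 vertices.
  let trails = 1 + palmTrails
  let vertexCount = trails * segments * 6
  return MeshBufferSizes(
    vertexCount: vertexCount,
    vertexBytes: vertexCount * 4 * MemoryLayout<Float>.stride,
    progressBytes: vertexCount * MemoryLayout<Float>.stride
  )
}

let launchOnly = meshBufferSizes(segments: 8, palmTrails: 0)
print(launchOnly.vertexCount, launchOnly.vertexBytes, launchOnly.progressBytes)
// 48 768 192
```

The results could be passed to device.makeBuffer(length:) in place of the constant. The uniform buffer would similarly need the number of instances times MemoryLayout<FireworkScene.Uniform>.stride bytes.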

The resources are the textures we'll use later for the glow effect. Additionally, we define a convenience method that builds the render pass descriptor used by the drawing calls.

private extension FireworkViewController {
  ...
 
  func buildResources(size: CGSize) {
    let width = Int(size.width)
    let height = Int(size.height)
 
    let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
      pixelFormat: .bgra8Unorm,
      width: width,
      height: height,
      mipmapped: false
    )
    textureDescriptor.usage = [.renderTarget, .shaderRead, .shaderWrite]
 
    sceneTexture = device.makeTexture(descriptor: textureDescriptor)
    glowTexture = device.makeTexture(descriptor: textureDescriptor)
  }
 
  func sceneRenderPassDescriptor(_ texture: any MTLTexture) -> MTLRenderPassDescriptor? {
    let descriptor = MTLRenderPassDescriptor()
    descriptor.colorAttachments[0].texture = texture
    descriptor.colorAttachments[0].loadAction = .clear
    descriptor.colorAttachments[0].storeAction = .store
    descriptor.colorAttachments[0].clearColor = MTLClearColorMake(0, 0, 0, 1)
    return descriptor
  }
}

Also add some utility code to wrap the view controller instance into a SwiftUI-friendly representation.

#if os(macOS)
 
struct FireworkView: NSViewControllerRepresentable {
  func makeNSViewController(context: Context) -> some NSViewController {
    FireworkViewController()
  }
 
  func updateNSViewController(
    _ nsViewController: NSViewControllerType,
    context: Context
  ) {}
}
 
#else
 
struct FireworkView: UIViewControllerRepresentable {
  func makeUIViewController(context: Context) -> some UIViewController {
    FireworkViewController()
  }
 
  func updateUIViewController(
    _ uiViewController: UIViewControllerType,
    context: Context
  ) {}
}
 
#endif

Firework Mesh

A mesh is a set of primitives used by Metal to create drawings. These primitives are defined by vertices, which are essentially points with some data attached to them.

We will generate the mesh in a separate structure. Start by defining a set of properties used in the calculations. The nested Uniform type is a set of constant values that we pass to the shaders through a buffer: the vertex shader reads it and forwards the values to the fragment shader.

struct FireworkScene {
  struct Uniform {
    let color: SIMD4<Float>
  }
 
  var trailStartPoint: SIMD2<Float>
  var trailEndPoint: SIMD2<Float>
  var trailControlPoint: SIMD2<Float>
  var ratio: Float
  var segments: Int = 8
  var trailWidth: Float = 0.05
}

Add the draw method that will hold calculations for multiple parts of a firework - a launch trail and a burst palm.

extension FireworkScene {
  func draw(
    vertexBuffer: inout ConstantBuffer<Float>,
    progressBuffer: inout ConstantBuffer<Float>
  ) {
    drawLaunchTrail(
      vertexBuffer: &vertexBuffer,
      progressBuffer: &progressBuffer
    )
  }
}

To simplify memory management and value handling in MTLBuffer, we'll add a ConstantBuffer type. It defines the logic for writing both arbitrary values and SIMD vectors into the buffer.

import MetalKit
 
struct ConstantBuffer<T> {
  let length: Int
  var data: UnsafeMutablePointer<T>
  var position: Int
 
  static var stepSize: Int {
    MemoryLayout<T>.stride
  }
 
  init(_ buffer: any MTLBuffer) {
    let rawPointer = buffer.contents()
    let typedPointer = rawPointer.bindMemory(
      to: T.self,
      capacity: buffer.length / Self.stepSize
    )
 
    self.init(
      buffer: typedPointer,
      elementsCount: buffer.length / Self.stepSize
    )
  }
  
  init(
    buffer: UnsafeMutablePointer<T>,
    elementsCount: Int
  ) {
    data = buffer
    length = elementsCount
    position = .zero
  }
 
  var availableSpace: UInt {
    UInt(length - position)
  }
 
  func hasSpace(for count: UInt) -> Bool {
    availableSpace >= count
  }
}
 
extension ConstantBuffer {
  /// Writes a whole value (for example a Uniform struct) at the current position.
  mutating func write(value: inout T) {
    guard hasSpace(for: 1) else { return }
    data[position] = value
    position = position &+ 1
  }
 
  mutating func append(_ value: T) {
    guard hasSpace(for: 1) else { return }
    appendRaw(value)
  }
 
  mutating func appendRaw(_ value: T) {
    data[position] = value
    position = position &+ 1
  }
}
 
extension ConstantBuffer where T: SIMDScalar {
  mutating func append(_ vector: SIMD3<T>) {
    append(vector.x)
    append(vector.y)
    append(vector.z)
  }
 
  mutating func append(_ vector: SIMD4<T>) {
    append(vector.x)
    append(vector.y)
    append(vector.z)
    append(vector.w)
  }
}
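Because ConstantBuffer also has an initializer that takes a plain UnsafeMutablePointer, its write logic can be checked without a GPU. The sketch below copies a trimmed version of the type (the MTLBuffer initializer is omitted so it builds without Metal) and points it at an ordinary heap allocation; runDemo is a hypothetical name used only for this check.

```swift
// Trimmed copy of ConstantBuffer without the MTLBuffer initializer.
struct ConstantBuffer<T> {
  let length: Int
  var data: UnsafeMutablePointer<T>
  var position: Int

  init(buffer: UnsafeMutablePointer<T>, elementsCount: Int) {
    data = buffer
    length = elementsCount
    position = .zero
  }

  var availableSpace: UInt { UInt(length - position) }

  func hasSpace(for count: UInt) -> Bool { availableSpace >= count }

  mutating func append(_ value: T) {
    guard hasSpace(for: 1) else { return }  // silently drops values that don't fit
    data[position] = value
    position += 1
  }
}

extension ConstantBuffer where T: SIMDScalar {
  mutating func append(_ vector: SIMD4<T>) {
    append(vector.x)
    append(vector.y)
    append(vector.z)
    append(vector.w)
  }
}

/// Writes two SIMD4 vectors into 6 Float slots and reports what landed.
func runDemo() -> (values: [Float], position: Int) {
  let capacity = 6
  let storage = UnsafeMutablePointer<Float>.allocate(capacity: capacity)
  storage.initialize(repeating: 0, count: capacity)
  defer { storage.deallocate() }

  var buffer = ConstantBuffer(buffer: storage, elementsCount: capacity)
  buffer.append(SIMD4<Float>(1, 2, 3, 4))  // fits: uses slots 0...3
  buffer.append(SIMD4<Float>(5, 6, 7, 8))  // only 5 and 6 fit; 7 and 8 are dropped
  return (Array(UnsafeBufferPointer(start: storage, count: capacity)), buffer.position)
}

let result = runDemo()
print(result.values, result.position)
// [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 6
```

Note that append(_:) for vectors checks space per component, so a vector that doesn't fit is written partially. appendRaw skips the check entirely, which is why the generous buffer size matters for the trail code.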

Drawing a trail requires evaluating a quadratic Bézier curve (its points and tangents) to build the mesh vertices and storing them in MTLBuffer instances. We iterate over the specified number of segments and calculate the edge points at each segment boundary.

Four edge points define a quad, and one quad is made of two triangles. Triangles are the primitives Metal uses to draw everything here, so we calculate these points and place them in such an order that every three consecutive points form a triangle.

trail mesh base

The dissolve progress values are tied to each segment. We multiply the value by 0.5 so that the launch trail occupies the first half of the progress range, which corresponds to a single keyframe. A progress value is added for every vertex, so there is clearly data duplication with this approach. Indexed drawing would fix this, but we will leave that topic for the next article.

trail mesh progress

Applying the trail width is also interesting. A Bézier curve only gives us the points along the middle of the trail. To get the edge points, we calculate the curve's tangent vector, which is essentially the direction of the curve at that point.

trail mesh width

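We then swap the tangent's components and negate one of them, which rotates it by 90 degrees and gives the normal, a vector perpendicular to the tangent. Since the normal has unit length, we can use it to shift the middle point by any distance, in our case half of the trail width to each side.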

trail mesh width

Add this implementation to the scene structure. Note that for the very first point (i == 0) we only store its edge points and skip emitting triangles, because a single pair of points is not enough to build a quad.

private extension FireworkScene {
  func drawLaunchTrail(
    vertexBuffer: inout ConstantBuffer<Float>,
    progressBuffer: inout ConstantBuffer<Float>
  ) {
    let tIncrement = 1.0 / Float(segments)
    var previousPoints: [SIMD2<Float>] = []
 
    for i in 0...segments {
      let t = Float(i) * tIncrement
 
      let currentPoint = quadraticBezier(
        trailStartPoint,
        trailControlPoint,
        trailEndPoint,
        t
      )
 
      let tangent = quadraticBezierTangent(
        trailStartPoint,
        trailControlPoint,
        trailEndPoint,
        t
      )
 
      let normalizedTangent = normalize(tangent)
      let normal = SIMD2<Float>(
        -normalizedTangent.y,
         normalizedTangent.x
      )
 
      let offset = -trailWidth / 2
 
      let p1 = currentPoint + normal * offset
      let p2 = currentPoint - normal * offset
 
      guard i > 0 else {
        previousPoints = [p1, p2]
        continue
      }
 
      let p_current = 0.5 * Float(i) / (Float(segments))
      let p_previous = 0.5 * Float(i - 1) / (Float(segments))
 
      appendVertex(previousPoints[0], to: &vertexBuffer)
      appendVertex(p1, to: &vertexBuffer)
      appendVertex(previousPoints[1], to: &vertexBuffer)
 
      progressBuffer.appendRaw(p_previous)
      progressBuffer.appendRaw(p_current)
      progressBuffer.appendRaw(p_previous)
 
      appendVertex(previousPoints[1], to: &vertexBuffer)
      appendVertex(p1, to: &vertexBuffer)
      appendVertex(p2, to: &vertexBuffer)
 
      progressBuffer.appendRaw(p_previous)
      progressBuffer.appendRaw(p_current)
      progressBuffer.appendRaw(p_current)
 
      previousPoints = [p1, p2]
    }
  }
}
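To make the vertex order concrete, here is a small standalone sketch of what one loop iteration emits: given the edge points of the previous and current cross-sections, it returns the six positions and six progress values in the same order as drawLaunchTrail. SegmentMesh and segmentTriangles are hypothetical names used only for illustration.

```swift
/// Positions and progress values emitted for one trail segment.
struct SegmentMesh {
  var positions: [SIMD2<Float>] = []
  var progress: [Float] = []
}

/// Hypothetical helper mirroring one loop iteration of drawLaunchTrail (i >= 1).
/// `previous` holds previousPoints[0] and previousPoints[1]; `current` holds p1 and p2.
func segmentTriangles(
  i: Int,
  segments: Int,
  previous: (SIMD2<Float>, SIMD2<Float>),
  current: (SIMD2<Float>, SIMD2<Float>)
) -> SegmentMesh {
  // The launch trail's progress is squeezed into 0...0.5, the first keyframe.
  let pCurrent = 0.5 * Float(i) / Float(segments)
  let pPrevious = 0.5 * Float(i - 1) / Float(segments)

  var mesh = SegmentMesh()
  // Triangle 1: previous edge A, current edge A, previous edge B.
  mesh.positions += [previous.0, current.0, previous.1]
  mesh.progress += [pPrevious, pCurrent, pPrevious]
  // Triangle 2: previous edge B, current edge A, current edge B.
  mesh.positions += [previous.1, current.0, current.1]
  mesh.progress += [pPrevious, pCurrent, pCurrent]
  return mesh
}

// Last segment (i == segments) of a straight, vertical, 0.2-wide trail.
let last = segmentTriangles(
  i: 8,
  segments: 8,
  previous: (SIMD2<Float>(-0.1, 0.0), SIMD2<Float>(0.1, 0.0)),
  current: (SIMD2<Float>(-0.1, 0.2), SIMD2<Float>(0.1, 0.2))
)
print(last.positions.count, last.progress)
// 6 [0.4375, 0.5, 0.4375, 0.4375, 0.5, 0.5]
```

The shared diagonal (previous edge B and current edge A) appears in both triangles. This is the duplication that indexed drawing would remove: four unique points instead of six entries per segment.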

Also add the utility methods for calculating curve points and filling the vertex buffer with point values.

private extension FireworkScene {
  func quadraticBezier(
    _ p0: SIMD2<Float>,
    _ p1: SIMD2<Float>,
    _ p2: SIMD2<Float>,
    _ t: Float
  ) -> SIMD2<Float> {
    let oneMinusT = 1 - t
    return oneMinusT * oneMinusT * p0 + 2 * oneMinusT * t * p1 + t * t * p2
  }
 
  func quadraticBezierTangent(
    _ p0: SIMD2<Float>,
    _ p1: SIMD2<Float>,
    _ p2: SIMD2<Float>,
    _ t: Float
  ) -> SIMD2<Float> {
    let oneMinusT = 1 - t
    return 2 * oneMinusT * (p1 - p0) + 2 * t * (p2 - p1)
  }
 
  func appendVertex(
    _ point: SIMD2<Float>,
    to buffer: inout ConstantBuffer<Float>
  ) {
    buffer.appendRaw(point.x * ratio)
    buffer.appendRaw(point.y)
    buffer.appendRaw(0.0)
    buffer.appendRaw(1.0)
  }
}
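A quick way to sanity-check these helpers, together with the normal construction described earlier, is to compare the analytic tangent with a finite difference and confirm that the normal is perpendicular to it. The sketch below re-declares the two Bézier functions as free functions (so it runs outside FireworkScene and without the simd module's normalize; the length helper is our own stand-in) and uses the control points from the draw call.

```swift
/// Point on a quadratic Bézier curve (same formula as in FireworkScene).
func quadraticBezier(
  _ p0: SIMD2<Float>, _ p1: SIMD2<Float>, _ p2: SIMD2<Float>, _ t: Float
) -> SIMD2<Float> {
  let oneMinusT = 1 - t
  let a = oneMinusT * oneMinusT * p0
  let b = 2 * oneMinusT * t * p1
  let c = t * t * p2
  return a + b + c
}

/// Derivative of the curve with respect to t.
func quadraticBezierTangent(
  _ p0: SIMD2<Float>, _ p1: SIMD2<Float>, _ p2: SIMD2<Float>, _ t: Float
) -> SIMD2<Float> {
  let oneMinusT = 1 - t
  return 2 * oneMinusT * (p1 - p0) + 2 * t * (p2 - p1)
}

/// Euclidean length without the simd module.
func length(_ v: SIMD2<Float>) -> Float {
  (v * v).sum().squareRoot()
}

// Control points from the draw(in:) call.
let p0 = SIMD2<Float>(-0.4, -0.8)
let p1 = SIMD2<Float>(-0.3, -0.1)
let p2 = SIMD2<Float>(0.1, 0.5)
let t: Float = 0.3

// 1. Analytic tangent vs. a central finite difference.
let h: Float = 1e-3
let analytic = quadraticBezierTangent(p0, p1, p2, t)
let numeric = (quadraticBezier(p0, p1, p2, t + h) - quadraticBezier(p0, p1, p2, t - h)) / (2 * h)

// 2. Rotate the unit tangent by 90 degrees to get the normal.
let unitTangent = analytic / length(analytic)
let normal = SIMD2<Float>(-unitTangent.y, unitTangent.x)
let dotProduct = (normal * unitTangent).sum()  // ~0 means perpendicular

// 3. Shift the center point by half the width to each side.
let trailWidth: Float = 0.03
let center = quadraticBezier(p0, p1, p2, t)
let offset = -trailWidth / 2
let edge1 = center + normal * offset
let edge2 = center - normal * offset
print(length(analytic - numeric), dotProduct, length(edge1 - edge2))
```

Because the normal has unit length, the two edge points end up exactly trailWidth apart, whatever the local direction of the curve.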

Update the draw call by wrapping the buffers and preparing the scene data. Calling firework.draw fills the buffers with the vertex data used later in the render pass.

extension FireworkViewController: MTKViewDelegate {
  func draw(in view: MTKView) {
    var vBufferWrapper = ConstantBuffer<Float>(vertexBuffer)
    var pBufferWrapper = ConstantBuffer<Float>(progressBuffer)
    var uBufferWrapper = ConstantBuffer<FireworkScene.Uniform>(uniformBuffer)
 
    let firework = FireworkScene(
      trailStartPoint: SIMD2<Float>(-0.4, -0.8),
      trailEndPoint: SIMD2<Float>(0.1, 0.5),
      trailControlPoint: SIMD2<Float>(-0.3, -0.1),
      ratio: Float(view.bounds.height / view.bounds.width),
      trailWidth: 0.03
    )
 
    var uniform = FireworkScene.Uniform(
      color: SIMD4<Float>(208.0 / 255.0, 80.0 / 255.0, 111.0 / 255.0, 1.0)
    )
    uBufferWrapper.write(value: &uniform)
 
    firework.draw(
      vertexBuffer: &vBufferWrapper,
      progressBuffer: &pBufferWrapper
    )
  }
 
  ...
}

The rest of the method binds the buffers and values to the Metal pipeline and issues the draw call. Note that we call the drawPrimitives method with the instanceCount parameter; we'll come back to it a little later when we increase the number of fireworks. Also, for now, we pass the drawable's texture instead of sceneTexture to the render pass descriptor builder. This will change once we start working on the glow effect.

extension FireworkViewController: MTKViewDelegate {
  func draw(in view: MTKView) {
		...
 
    guard
      let commandBuffer = commandQueue.makeCommandBuffer(),
      let drawable = view.currentDrawable
    else { return }
 
    let vertexCount = vBufferWrapper.position / 4
 
    if
      let renderPassDescriptor = sceneRenderPassDescriptor(drawable.texture),
      let renderEncoder = commandBuffer.makeRenderCommandEncoder(
        descriptor: renderPassDescriptor
      )
    {
      renderEncoder.setRenderPipelineState(scenePipelineState)
 
      renderEncoder.setVertexBuffer(
        vertexBuffer,
        offset: 0,
        index: 0
      )
      renderEncoder.setVertexBuffer(
        progressBuffer,
        offset: 0,
        index: 1
      )
      renderEncoder.setVertexBuffer(
        uniformBuffer,
        offset: 0,
        index: 2
      )
 
      renderEncoder.drawPrimitives(
        type: .triangle,
        vertexStart: 0,
        vertexCount: vertexCount,
        instanceCount: 1
      )
 
      renderEncoder.endEncoding()
    }
 
    commandBuffer.present(drawable)
    commandBuffer.commit()
  }
 
  ...
}

In the second delegate method, we rebuild the textures so that they always match the current size.

extension FireworkViewController: MTKViewDelegate {
	...
 
  func mtkView(
    _ view: MTKView,
    drawableSizeWillChange size: CGSize
  ) {
    buildResources(size: size)
  }
}

At this step the Metal code simply passes the incoming vertex data through to the fragment stage. Here we replicate the Uniform structure defined earlier in FireworkScene.

#include <metal_stdlib>
using namespace metal;
 
namespace Firework {
  struct Uniform {
    packed_float4 color;
  };
 
  struct VertexOut {
    float4 position [[position]];
    float4 color;
    float progress;
  };
 
  vertex VertexOut vertexScene(
    uint vid [[ vertex_id ]],
    uint iid [[ instance_id ]],
    constant packed_float4* position [[ buffer(0) ]],
    constant float* progress [[ buffer(1) ]],
    constant Uniform* uniform [[ buffer(2) ]]
  ) {
    VertexOut out;
 
    out.position = position[vid];
    out.color = uniform[iid].color;
    out.progress = progress[vid];
 
    return out;
  }
 
  fragment float4 fragmentScene(
    VertexOut in [[stage_in]]
  ) {
    return in.color;
  }
}

After running the code you should see the following curved trail.

drawn trail

Dissolving Trail

Time to add some movement to our trail. Start by defining the new parameters; their names describe their purpose.

final class FireworkViewController: ViewController {
  ...
 
  let duration: Float = 3.0
  private let initialTime = CACurrentMediaTime()
}

Update the draw call by calculating the elapsedTime parameter. Using truncatingRemainder wraps the time so that it loops within the duration period, and dividing by duration turns it into a 0...1 progress value t.

func draw(in view: MTKView) {
  ...
 
  let elapsedTime = Float(CACurrentMediaTime() - initialTime).truncatingRemainder(dividingBy: duration)
  let t = elapsedTime / duration
 
  var uniform = FireworkScene.Uniform(
    color: SIMD4<Float>(208.0 / 255.0, 80.0 / 255.0, 111.0 / 255.0, 1.0),
    time: t
  )
  uBufferWrapper.write(value: &uniform)
 
  ...
}
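To see the looping behaviour in isolation, here is the same arithmetic as a small standalone function. loopedProgress is a hypothetical name; in the article the expression is written inline in draw(in:), with the elapsed seconds coming from CACurrentMediaTime().

```swift
let duration: Float = 3.0

/// Maps seconds since start to a looping 0..<1 animation progress,
/// using the same truncatingRemainder trick as draw(in:).
func loopedProgress(elapsedSeconds: Float, duration: Float) -> Float {
  let elapsedTime = elapsedSeconds.truncatingRemainder(dividingBy: duration)
  return elapsedTime / duration
}

for seconds: Float in [0.0, 1.5, 2.999, 3.0, 7.5] {
  print(seconds, loopedProgress(elapsedSeconds: seconds, duration: duration))
}
```

At exactly 3.0 seconds the progress jumps back to 0, and 7.5 seconds lands halfway through the third loop, just like 1.5 seconds does in the first.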

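Add the time parameter to both Uniform structures. On the Swift side, FireworkScene.Uniform gets a let time: Float property after color, otherwise the initializer call above won't compile. On the Metal side, add it to the vertex structures and copy it into the vertex output.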

namespace Firework {
  struct Uniform {
    packed_float4 color;
    float time;
  };
 
  struct VertexOut {
    float4 position [[position]];
    float4 color;
    float progress;
    float time;
  };
 
  vertex VertexOut vertexScene(
    uint vid [[ vertex_id ]],
    uint iid [[ instance_id ]],
    constant packed_float4* position [[ buffer(0) ]],
    constant float* progress [[ buffer(1) ]],
    constant Uniform* uniform [[ buffer(2) ]]
  ) {
    VertexOut out;
 
    out.position = position[vid];
    out.color    = uniform[iid].color;
    out.progress = progress[vid];
    out.time = uniform[iid].time;
 
    return out;
  }
  
	...
}

At this step most of the changes happen in the fragment shader.

We mentioned earlier that there are two keyframes, each taking up half of the progress of the entire firework animation. The dissolve animation must take place entirely within one keyframe. To achieve this, we calculate normalized representations of time and progress by remapping their values within the current keyframe to the range 0...1. The breakpoint constant defines the point in time where one keyframe replaces the other.

We compare normalizedProgress with normalizedTime to discard fragments whose time has not come yet.

The step with _noise and noiseFade produces the dissolve effect. I have written about the implementation of this effect in my previous article. Here we use a slightly different version that lets us control the noise falloff.

The last component of this shader is the bloom effect. It simulates the heat of the trail: parts closer to the trail head have a "high temperature" and are correspondingly lighter in color. As we move away from the head, the temperature drops and the glow decreases.

namespace Firework {
  ...
 
  float rand(float2 n) {
    return fract(sin(dot(n, n)) * length(n));
  }
   
  float noise(float2 n) {
    const float2 d = float2(0.0, 1.0);
   
    float2 b = floor(n);
    float2 f = smoothstep(float2(0.0), float2(1.0), fract(n));
   
    return mix(
      mix(rand(b),           rand(b + d.yx), f.x),
      mix(rand(b + d.xy),    rand(b + d.yy), f.x),
      f.y
    );
  }
 
  fragment float4 fragmentScene(
    VertexOut in [[stage_in]]
  ) {
    const float breakpoint = 0.5;
    const float noiseScale = 0.8;
    const float noiseFalloff = 3.0;
    const float bloomIntensity = 2.0;
    const float bloomFalloff = 20.0;
 
    const float time = in.time;
    const float dissolveProgress = in.progress;
    
    const bool isSecondPhase = time > breakpoint;
 
    const float normalizedTime = isSecondPhase ? (time - breakpoint) / (1.0 - breakpoint) : time / breakpoint;
    const float normalizedProgress = isSecondPhase ? (dissolveProgress - breakpoint) / (1.0 - breakpoint) : dissolveProgress / breakpoint;
 
    if (normalizedProgress > normalizedTime) discard_fragment();
 
    float _noise = 1.0 - noise(in.position.xy * noiseScale);
 
    float delayPeriod = isSecondPhase ? 0.0 : 0.25;
    float delayFactor = (pow(normalizedTime, noiseFalloff * (isSecondPhase ? 1.0 : 1.3)) - delayPeriod) / (1.0 - delayPeriod);
    float noiseFade = smoothstep(0.0, normalizedTime, normalizedProgress - delayFactor);
 
    if (_noise > noiseFade) discard_fragment();
 
    float colorFactor = clamp(1.0 - (normalizedTime - normalizedProgress), 0.0, 1.0);
    float bloomEffect = bloomIntensity * pow(colorFactor, bloomFalloff);
    float bloomFade = 1.0 - smoothstep(0.9, 1.0, normalizedTime);
 
    bloomEffect *= bloomFade;
    float3 finalColor = in.color.rgb + in.color.rgb * bloomEffect;
 
    return float4(finalColor, 1.0);
  }
}
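The keyframe remapping at the top of fragmentScene is easier to reason about on the CPU. The sketch below ports only the breakpoint arithmetic and the first discard test to Swift (remap and passesTimeTest are our own names, not part of the shader; the noise step is ignored). It also shows why palm fragments, whose progress starts at 0.5, stay hidden during the first keyframe.

```swift
let breakpoint: Float = 0.5

/// Same remapping as normalizedTime / normalizedProgress in fragmentScene:
/// each half of the timeline is stretched to 0...1.
func remap(_ value: Float, isSecondPhase: Bool) -> Float {
  isSecondPhase
    ? (value - breakpoint) / (1.0 - breakpoint)
    : value / breakpoint
}

/// Mirrors `if (normalizedProgress > normalizedTime) discard_fragment();`.
/// Returns true when a fragment survives this first test.
func passesTimeTest(time: Float, progress: Float) -> Bool {
  let isSecondPhase = time > breakpoint
  let normalizedTime = remap(time, isSecondPhase: isSecondPhase)
  let normalizedProgress = remap(progress, isSecondPhase: isSecondPhase)
  return normalizedProgress <= normalizedTime
}

// First keyframe (time 0.25).
print(passesTimeTest(time: 0.25, progress: 0.1))  // true: launch trail, already reached
print(passesTimeTest(time: 0.25, progress: 0.3))  // false: launch trail, not reached yet
print(passesTimeTest(time: 0.25, progress: 0.6))  // false: palm, hidden in first keyframe
// Second keyframe (time 0.75).
print(passesTimeTest(time: 0.75, progress: 0.6))  // true: palm, already reached
print(passesTimeTest(time: 0.75, progress: 0.9))  // false: palm, not reached yet
```

In the second keyframe, launch-trail fragments get a negative normalizedProgress, so they pass this test and it is left to the noise-fade step that follows to dissolve them.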

Missing Palm

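Now it's time to add the missing part of the firework: the palm. The trajectory calculations for each palm trail are similar to the launch trail. I wrote about how trigonometry is involved in calculating the end points of each trail in the previous article.

The gravity parameter controls how quickly the palm trails "fall". Note that the progress values are offset by 0.5, so the palm animates in the second keyframe, after the launch trail. The code also relies on two FireworkScene properties that haven't been declared yet: burstPalmTrails (the number of palm trails, an Int) and burstPalmRadius (the initial speed of each trail, a Float). Add them as stored properties next to segments and trailWidth, with values of your choice.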

private extension FireworkScene {
	...
 
  private func drawBurstPalm(
    vertexBuffer: inout ConstantBuffer<Float>,
    progressBuffer: inout ConstantBuffer<Float>
  ) {
    let tIncrement = 1.0 / Float(segments)
    let gravity: Float = 0.8
 
    for i in 0..<burstPalmTrails {
      let angle = (Float.pi * 2 * Float(i)) / Float(burstPalmTrails)
 
      let initialVelocity = SIMD2<Float>(
        cos(angle),
        sin(angle)
      ) * burstPalmRadius
 
      let trailStart = trailEndPoint
 
      var previousPoints: [SIMD2<Float>] = []
 
      for j in 0...segments {
        let t = Float(j) * tIncrement
 
        let currentPoint = SIMD2<Float>(
          trailStart.x + initialVelocity.x * t,
          trailStart.y + initialVelocity.y * t - (gravity * t * t) / 2
        )
 
        let tangent = SIMD2<Float>(
          initialVelocity.x,
          initialVelocity.y - gravity * t
        )
        let normalizedTangent = normalize(tangent)
 
        let normal = SIMD2<Float>(-normalizedTangent.y, normalizedTangent.x)
 
        let offset = -trailWidth / 2
 
        let p1 = currentPoint + normal * offset
        let p2 = currentPoint - normal * offset
 
        guard j > 0 else {
          previousPoints = [p1, p2]
          continue
        }
 
        let p_current = 0.5 * Float(j) / Float(segments) + 0.5
        let p_previous = 0.5 * Float(j - 1) / Float(segments) + 0.5
 
        appendVertex(previousPoints[0], to: &vertexBuffer)
        appendVertex(p1, to: &vertexBuffer)
        appendVertex(previousPoints[1], to: &vertexBuffer)
 
        progressBuffer.appendRaw(p_previous)
        progressBuffer.appendRaw(p_current)
        progressBuffer.appendRaw(p_previous)
 
        appendVertex(previousPoints[1], to: &vertexBuffer)
        appendVertex(p1, to: &vertexBuffer)
        appendVertex(p2, to: &vertexBuffer)
 
        progressBuffer.appendRaw(p_previous)
        progressBuffer.appendRaw(p_current)
        progressBuffer.appendRaw(p_current)
 
        previousPoints = [p1, p2]
      }
    }
  }
}
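Each palm trail is a simple projectile: its position is start + velocity × t with an extra downward term gravity × t² / 2, and its tangent is the derivative (velocity.x, velocity.y − gravity × t), as in the code above. The standalone sketch below evaluates the end point of each trail at t = 1. The values burstPalmTrails = 4 and burstPalmRadius = 0.3 are arbitrary placeholders chosen for this example, not values from the article.

```swift
import Foundation

let gravity: Float = 0.8
let burstPalmTrails = 4                  // placeholder value for this sketch
let burstPalmRadius: Float = 0.3         // placeholder value for this sketch
let trailStart = SIMD2<Float>(0.1, 0.5)  // trailEndPoint from draw(in:)

/// Ballistic position: start + v * t, pulled down by gravity * t^2 / 2.
func palmPoint(velocity: SIMD2<Float>, t: Float) -> SIMD2<Float> {
  SIMD2<Float>(
    trailStart.x + velocity.x * t,
    trailStart.y + velocity.y * t - (gravity * t * t) / 2
  )
}

var endPoints: [SIMD2<Float>] = []
for i in 0..<burstPalmTrails {
  // Evenly spaced directions around the circle, as in drawBurstPalm.
  let angle = (Float.pi * 2 * Float(i)) / Float(burstPalmTrails)
  let velocity = SIMD2<Float>(cos(angle), sin(angle)) * burstPalmRadius
  endPoints.append(palmPoint(velocity: velocity, t: 1))
}
print(endPoints)
```

Without gravity the end points would lie on a circle of radius burstPalmRadius around the burst point; with it, every end point is shifted down by gravity / 2 = 0.4, which is what gives the palm its drooping shape.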

Don't forget to invoke this method in the drawing call.

extension FireworkScene {
  func draw(
    vertexBuffer: inout ConstantBuffer<Float>,
    progressBuffer: inout ConstantBuffer<Float>
  ) {
    drawLaunchTrail(
      vertexBuffer: &vertexBuffer,
      progressBuffer: &progressBuffer
    )
    drawBurstPalm(
      vertexBuffer: &vertexBuffer,
      progressBuffer: &progressBuffer
    )
  }
}

Glowing Trails

The firework we created is missing a surrounding glow. Blurring does a good job of simulating it, and Metal Performance Shaders provides a standard type, MPSImageGaussianBlur, which suits us quite well. Its encode call takes two textures: one as the data source and the other as the destination for the result. We defined sceneTexture and glowTexture earlier, so let's use them now. Also replace the previous use of drawable.texture with sceneTexture, so the firework is drawn into this texture.

If you run the code now, you won't see anything. This is because nothing is written to the drawable's texture anymore. To make it work, we need to combine both textures, sceneTexture and glowTexture, and write the result to the drawable.

import MetalPerformanceShaders
 
func draw(in view: MTKView) {
  ...
 
  if
    let renderPassDescriptor = sceneRenderPassDescriptor(sceneTexture), // <-- attention here
    let renderEncoder = commandBuffer.makeRenderCommandEncoder(
      descriptor: renderPassDescriptor
    )
  {
    ...
  }
 
  let kernel = MPSImageGaussianBlur(
    device: device,
    sigma: 40.0
  )
  kernel.encode(
    commandBuffer: commandBuffer,
    sourceTexture: sceneTexture,
    destinationTexture: glowTexture
  )
 
  commandBuffer.present(drawable)
  commandBuffer.commit()
}

We define additional properties for the post-processing composition pass.

final class FireworkViewController: ViewController {
  ...
 
  private var sceneBuffer: (any MTLBuffer)!
  private var compositePipelineState: (any MTLRenderPipelineState)!
}

The composition stage covers the entire screen. To do so, we only need two triangles forming a full-screen rectangle, which we describe with four vertices drawn as a triangle strip.

func buildBuffers() {
  ...
 
  // Full-screen quad: each row is (x, y) in clip space followed by (u, v) texture coordinates.
  let quadVertices: [Float] = [
    -1.0,   1.0,  0.0,  0.0,
     1.0,   1.0,  1.0,  0.0,
    -1.0,  -1.0,  0.0,  1.0,
     1.0,  -1.0,  1.0,  1.0,
  ]
 
  sceneBuffer = device.makeBuffer(
    bytes: quadVertices,
    length: MemoryLayout<Float>.stride * quadVertices.count,
    options: []
  )
}
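Each row of quadVertices is (x, y, u, v): a clip-space position followed by a texture coordinate. Clip space has y pointing up, while Metal texture coordinates start at the top-left corner, so v is flipped relative to y. The sketch below checks that every row follows u = (x + 1) / 2 and v = (1 − y) / 2, which is one way to convince yourself the composite pass won't render the scene upside down.

```swift
let quadVertices: [Float] = [
  -1.0,  1.0, 0.0, 0.0,
   1.0,  1.0, 1.0, 0.0,
  -1.0, -1.0, 0.0, 1.0,
   1.0, -1.0, 1.0, 1.0,
]

// Walk the flat array four floats at a time, matching `constant float4*` in the shader.
var rowsMatch = true
for start in stride(from: 0, to: quadVertices.count, by: 4) {
  let x = quadVertices[start], y = quadVertices[start + 1]
  let u = quadVertices[start + 2], v = quadVertices[start + 3]
  if u != (x + 1) / 2 || v != (1 - y) / 2 { rowsMatch = false }
}

// A triangle strip of N vertices yields N - 2 triangles.
let triangleCount = quadVertices.count / 4 - 2
print(rowsMatch, triangleCount)  // true 2
```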

And add some more template code to configure the state descriptor.

func buildPipelineStates() {
  ...
 
  let compositePipelineDescriptor = MTLRenderPipelineDescriptor()
  compositePipelineDescriptor.vertexFunction = library.makeFunction(
    name: "Firework::vertexComposition"
  )
  compositePipelineDescriptor.fragmentFunction = library.makeFunction(
    name: "Firework::fragmentComposition"
  )
  compositePipelineDescriptor.colorAttachments[0].pixelFormat = .bgra8Unorm
 
  do {
    compositePipelineState = try device.makeRenderPipelineState(
      descriptor: compositePipelineDescriptor
    )
  } catch {
    fatalError("Failed to create composite pipeline state: \(error)")
  }
}

Back in the draw method, we create a new render pass targeting the drawable, which displays the final result on the screen. Here we also pass the two textures we defined earlier as fragment shader parameters.

func draw(in view: MTKView) {
  ...
 
  if
    let renderPassDescriptor = sceneRenderPassDescriptor(drawable.texture),
    let renderEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor)
  {
    renderEncoder.setRenderPipelineState(compositePipelineState)
    renderEncoder.setVertexBuffer(sceneBuffer, offset: 0, index: 0)
    renderEncoder.setFragmentTexture(sceneTexture, index: 0)
    renderEncoder.setFragmentTexture(glowTexture, index: 1)
    renderEncoder.drawPrimitives(
      type: .triangleStrip,
      vertexStart: 0,
      vertexCount: 4
    )
    renderEncoder.endEncoding()
  }
 
  commandBuffer.present(drawable)
  commandBuffer.commit()
}

The CompositionOut structure holds texCoord, which is used to sample the incoming textures and get a color for each pixel. The composition fragment returns the sum of the colors from both textures, which is like overlaying two drawings with transparent backgrounds to make one. Here bloomIntensity defines how much the bloom texture affects the final result.

namespace Firework {
  ...
 
  struct CompositionOut {
    float4 position [[position]];
    float2 texCoord;
  };
 
  vertex CompositionOut vertexComposition(
    constant float4* vertexData [[ buffer(0) ]],
    uint vertexID [[vertex_id]]
  ) {
    CompositionOut out;
 
    float2 position = vertexData[vertexID].xy;
    float2 texCoord = vertexData[vertexID].zw;
 
    out.position = float4(position, 0.0, 1.0);
    out.texCoord = texCoord;
 
    return out;
  }
 
  fragment float4 fragmentComposition(
    CompositionOut in [[stage_in]],
    texture2d<float> sceneTexture [[texture(0)]],
    texture2d<float> bloomTexture [[texture(1)]]
  ) {
    constexpr sampler textureSampler (mag_filter::linear, min_filter::linear);
    float4 sceneColor = sceneTexture.sample(textureSampler, in.texCoord);
    float4 bloomColor = bloomTexture.sample(textureSampler, in.texCoord);
 
    float bloomIntensity = 3.0;
    return sceneColor + bloomColor * bloomIntensity;
  }
}
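Per pixel, the composition fragment is just scene + bloom × bloomIntensity. The sketch below mirrors that in Swift for a single RGBA value and makes one detail explicit: the drawable uses .bgra8Unorm, and values written to a normalized 8-bit target are clamped to 0...1, so a strong bloom saturates channel by channel instead of brightening forever. composite is a hypothetical helper, and the pixel values are made-up samples.

```swift
/// Mirrors fragmentComposition for one pixel, then applies the 0...1 clamp
/// that happens when the result is stored in a .bgra8Unorm render target.
func composite(
  scene: SIMD4<Float>,
  bloom: SIMD4<Float>,
  bloomIntensity: Float = 3.0
) -> SIMD4<Float> {
  let sum = scene + bloom * bloomIntensity
  return sum.clamped(lowerBound: .zero, upperBound: SIMD4<Float>(repeating: 1))
}

// A pixel on the trail: pink scene color plus some of its own blur.
let inside = composite(
  scene: SIMD4<Float>(0.82, 0.31, 0.44, 1.0),
  bloom: SIMD4<Float>(0.2, 0.08, 0.1, 1.0)
)
// A pixel beside the trail: black scene, only the blurred halo.
let halo = composite(
  scene: SIMD4<Float>(0, 0, 0, 1),
  bloom: SIMD4<Float>(0.05, 0.02, 0.03, 1.0)
)
print(inside, halo)
```

The red channel of the trail pixel hits the 1.0 ceiling while green and blue still have headroom, which is part of why the hot trail head reads as lighter than its base color.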

Now the firework has this surrounding bloom.

To make the animation even more natural, we can apply an easing function. You can find more examples of easing functions here.

func easeOutCubic(_ x: Float) -> Float {
  1 - pow(1 - x, 3)
}
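Plugging a few values into the easing function shows its shape. The standalone copy below prints eased values next to linear progress. One consequence worth noting: the eased value reaches the 0.5 breakpoint when (1 − x)³ = 0.5, i.e. at x ≈ 0.21, so the launch trail plays in roughly the first fifth of the duration and the palm gets the rest.

```swift
import Foundation

func easeOutCubic(_ x: Float) -> Float {
  1 - pow(1 - x, 3)
}

// Linear progress vs. eased progress at a few points.
for x: Float in [0, 0.25, 0.5, 0.75, 1] {
  print(x, easeOutCubic(x))
}
```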

Wrap the existing time calculation inside the draw call with the easing function.

let t = easeOutCubic(elapsedTime / duration)

Instanced Fireworks

Finally, we are ready to draw multiple fireworks. First, let's define the infrastructure to move them to different locations.

You may be familiar with CGAffineTransform, which allows you to manipulate the geometry of views. For our purposes matrix_float4x4 is better suited, since it maps directly to float4x4 in the shader and can be multiplied with the float4 vertex positions. Here we define a convenience initialiser that builds a translation matrix, which moves objects along the x and y axes.

extension matrix_float4x4 { 
  init(translationX x: Float, y: Float) {
    self.init(columns: (
      SIMD4<Float>(1, 0, 0, 0),
      SIMD4<Float>(0, 1, 0, 0),
      SIMD4<Float>(0, 0, 1, 0),
      SIMD4<Float>(x, y, 0, 1)
    ))
  }
}
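To see why the translation sits in the last column, the sketch below multiplies a vertex by such a matrix. It uses a minimal Matrix4 stand-in made of four SIMD4<Float> columns (an assumption, so the example runs without the simd module) with the same column-major layout as matrix_float4x4. A matrix-vector product scales each column by one vector component, so the last column is added only in proportion to w. That is why appendVertex writes w = 1.0.

```swift
/// Minimal column-major 4x4 matrix standing in for matrix_float4x4.
struct Matrix4 {
  var columns: (SIMD4<Float>, SIMD4<Float>, SIMD4<Float>, SIMD4<Float>)

  /// Same layout as init(translationX:y:) above: translation in the last column.
  init(translationX x: Float, y: Float) {
    columns = (
      SIMD4<Float>(1, 0, 0, 0),
      SIMD4<Float>(0, 1, 0, 0),
      SIMD4<Float>(0, 0, 1, 0),
      SIMD4<Float>(x, y, 0, 1)
    )
  }

  /// Matrix-vector product: each vector component scales one column.
  static func * (m: Matrix4, v: SIMD4<Float>) -> SIMD4<Float> {
    var result = m.columns.0 * v.x
    result += m.columns.1 * v.y
    result += m.columns.2 * v.z
    result += m.columns.3 * v.w
    return result
  }
}

let translation = Matrix4(translationX: -0.1, y: 0.2)

// A position (w = 1) is moved...
let moved = translation * SIMD4<Float>(0.1, 0.5, 0, 1)
// ...while a direction (w = 0) is left untouched.
let direction = translation * SIMD4<Float>(0.3, 0.3, 0, 0)
print(moved, direction)
```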

Next add the transform parameter to the Uniform structure.

struct FireworkScene {
  struct Uniform {
    let color: SIMD4<Float>
    let transform: simd_float4x4
    let time: Float
  }
 
  ...
}

Replace the single time calculation with these three. Note that for t2 and t3 we subtract a small offset from elapsedTime to simulate a delay.

let t1 = easeOutCubic(elapsedTime / duration)
var uniform1 = FireworkScene.Uniform(
  color: SIMD4<Float>(208.0 / 255.0, 80.0 / 255.0, 111.0 / 255.0, 1.0),
  transform: .init(translationX: 0.1, y: 0.0),
  time: t1
)
uBufferWrapper.write(value: &uniform1)
 
let t2 = easeOutCubic((elapsedTime - 0.2) / duration)
var uniform2 = FireworkScene.Uniform(
  color: SIMD4<Float>(100.0 / 255.0, 167.0 / 255.0, 230.0 / 255.0, 1.0),
  transform: .init(translationX: -0.1, y: 0.0),
  time: t2
)
uBufferWrapper.write(value: &uniform2)
 
let t3 = easeOutCubic((elapsedTime - 0.4) / duration)
var uniform3 = FireworkScene.Uniform(
  color: SIMD4<Float>(90.0 / 255.0, 98.0 / 255.0, 198.0 / 255.0, 1.0),
  transform: .init(translationX: 0.0, y: 0.2),
  time: t3
)
uBufferWrapper.write(value: &uniform3)
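It may look odd that elapsedTime − 0.2 goes negative at the start of every loop. The sketch below evaluates the three per-instance times (delays as in uniform1 to uniform3; instanceTimes is a hypothetical helper). Early in the loop the delayed instances get a negative t, so in the fragment shader normalizedTime is negative while normalizedProgress is never below zero, and every fragment is discarded: the delayed fireworks simply stay hidden until their delay has passed.

```swift
import Foundation

func easeOutCubic(_ x: Float) -> Float { 1 - pow(1 - x, 3) }

let duration: Float = 3.0
let delays: [Float] = [0.0, 0.2, 0.4]  // as in uniform1, uniform2, uniform3

/// Eased time for each firework instance at a given point in the loop.
func instanceTimes(elapsedTime: Float) -> [Float] {
  delays.map { easeOutCubic((elapsedTime - $0) / duration) }
}

let early = instanceTimes(elapsedTime: 0.1)
let late = instanceTimes(elapsedTime: 2.999)
print(early, late)
```

Since elapsedTime wraps at duration, the delayed instances never quite reach t = 1 before the loop restarts, but thanks to the easing they get very close.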

Also find the drawPrimitives call and increase instanceCount to 3.

renderEncoder.drawPrimitives(
  type: .triangle,
  vertexStart: 0,
  vertexCount: vertexCount,
  instanceCount: 3
)

In Metal code update the Uniform by adding the transform property.

namespace Firework {
  struct Uniform {
    packed_float4 color;
    float4x4 transform;
    float time;
  };
 
  ...
}
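Since the vertex shader indexes uniform[iid], the Metal struct's array stride has to match the Swift stride used when writing the buffer. By our reading of the Metal Shading Language alignment rules (packed_float4 is 16 bytes with 4-byte alignment, float4x4 is 64 bytes with 16-byte alignment), the Metal Uniform places transform at offset 16 and time at 80, which rounds up to a 96-byte stride. The sketch below checks the Swift side, using a tuple of four SIMD4<Float> columns as a stand-in for simd_float4x4 (an assumption made so it runs without the simd module; it has the same 64-byte, 16-byte-aligned layout).

```swift
// Stand-in for simd_float4x4: four 16-byte-aligned columns.
typealias Float4x4Columns = (SIMD4<Float>, SIMD4<Float>, SIMD4<Float>, SIMD4<Float>)

/// Final Uniform layout (color, transform, time).
struct Uniform {
  let color: SIMD4<Float>
  let transform: Float4x4Columns
  let time: Float
}

/// Intermediate layout from the dissolving-trail step (color, time).
struct ColorTimeUniform {
  let color: SIMD4<Float>
  let time: Float
}

print(
  MemoryLayout<Uniform>.offset(of: \Uniform.transform)!,
  MemoryLayout<Uniform>.offset(of: \Uniform.time)!,
  MemoryLayout<Uniform>.stride
)
print(MemoryLayout<ColorTimeUniform>.stride)
```

The intermediate color-plus-time version is worth a second look: Swift gives it a 32-byte stride, while a Metal struct of packed_float4 and float should have a 20-byte stride. That mismatch is harmless while only instance 0 is drawn, but it would misplace uniform[1] and beyond, so it pays to keep both layouts in sync whenever the struct changes.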

And in the vertex shader for the scene, update the position calculation by multiplying the buffer position by the transform matrix.

vertex VertexOut vertexScene(
  uint vid [[ vertex_id ]],
  uint iid [[ instance_id ]],
  constant packed_float4* position [[ buffer(0) ]],
  constant float* progress [[ buffer(1) ]],
  constant Uniform* uniform [[ buffer(2) ]]
) {
  VertexOut out;
 
  out.position = uniform[iid].transform * float4(position[vid]);
  out.color    = uniform[iid].color;
  out.progress = progress[vid];
  out.time = uniform[iid].time;
 
  return out;
}

Conclusion

When I started building this experiment, I had no idea about instanced rendering and had some difficulty passing uniforms. After many attempts I got the desired result, but even it leaves plenty of room for improvement and optimisation. In my opinion, this reflects the learning process very well: to master something, you have to fail many times, and that happens when you try things on your own. So I encourage you to try this code yourself and see what you can come up with.

Here is a gist with the final code.

And some references that helped me:

See you in the next experiments! 🙌