Growing plants with code

Objectives:

In this experiment, I wanted to learn how to grow a branch along a path and generate procedural branchlets. This was quite a large experiment that accomplished multiple things:

  • Making a mesh "grow" along a path drawn in Blender.
  • Generating new branchlets procedurally that will expand from the main branch.
  • Adding leaves that will grow along with the branchlets.
  • Making everything react to scroll.

TL;DR

To achieve the "growing" effect, I started off with 3D paths that guide branch growth. The branches are rendered as cylinders, morphed to align with these paths. To do this efficiently, I precomputed some key values for each path point, called PathVertices. These include the distance, rotation and position of each path point. Then, I used those values in the vertex shader to transform the vertices of the cylinder.

Getting the base model

I started with a base plant model from Sketchfab. After getting the model into Blender, I got rid of the original branches and added Bezier curves to guide the growth later.

For proper exporting and parsing later, I turned those curves into geometry and renamed them as Branch.[number]. Also, I included a leaf and a stick model for later use in the experiment. The model was exported as plant.glb.

Main setup

I used zustand and leva to create a store and controls to manage the plant's state. This store keeps track of values like:

  • grow: number -> The amount of growth that the plant will have.
  • debugCamera: boolean -> Whether to use an external camera and helpers.
  • debugGrid: boolean -> Whether to show the debug lines or not.

import { useControls } from "leva";
import { create } from "zustand";
import { shallow } from "zustand/shallow";

interface ConfigStore {
  grow: number;
  debugCamera: boolean;
  debugGrid: boolean;
}

export const useConfigStore = create<ConfigStore>(() => ({
  grow: 10,
  debugCamera: false,
  debugGrid: false,
}));

/** Should be used only once */
export const useConfigControls = () => {
  useControls(() => ({
    // disabled in prod
    grow: {
      value: 0,
      min: 0,
      max: 1.3,
      step: 0.001,
      onChange: (value) => {
        useConfigStore.setState({ grow: value });
      },
    },
    debugCamera: {
      value: false,
      onChange: (value) => {
        useConfigStore.setState({ debugCamera: value });
      },
    },
    debugGrid: {
      value: false,
      onChange: (value) => {
        useConfigStore.setState({ debugGrid: value });
      },
    },
  }));

  return;
};

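export const useConfig = () => {
  return useConfigStore(
    (state) => ({
      grow: state.grow,
      // the store has no `debug` field, so select the two debug flags it defines
      debugCamera: state.debugCamera,
      debugGrid: state.debugGrid,
    }),
    shallow
  );
};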

Then, I created a main component that imports the Branches and Pot components and exports the GLTF nodes that will be used to parse the model.

import { LineSegments, Mesh } from "three";
import type { GLTF } from "three-stdlib";
import { Branches } from "./branches";
import { Pot } from "./pot";

export interface PlantGLTF extends GLTF {
  nodes: {
    pot: Mesh;
    stick: Mesh;
    Branch: LineSegments;
    leaf: Mesh;
  };
}

export const Plant = () => {
  return (
    <group position={[0, 0, 0]}>
      <Branches />
      <Pot />
    </group>
  );
};

I also added a background stage to the scene:

Creating the main branches

To add the branches to our scene, I first needed to parse the data of the .glb file. Because we converted our Bezier curves into geometry, the data will be stored as LineSegments.

import { useMemo } from "react";
import { LineSegments } from "three";
import { GLTFLoader } from "three-stdlib";
import { useLoader } from "@react-three/fiber";
import { PlantGLTF } from ".";

export const Branches = () => {
  // Load the model using the GLTFLoader
  const plantModel = useLoader(
    GLTFLoader,
    "/experiment-shaders-plants-assets/plant.glb"
  ) as unknown as PlantGLTF;

  // Parse the model and find the LineSegments nodes that start with "Branch"
  const branches = useMemo(() => {
    return Object.values(plantModel.nodes).filter(
      (node) => node.name.startsWith("Branch") && node.type === "LineSegments"
    ) as LineSegments[];
  }, [plantModel]);
};

Next up, I put together some uniforms for all materials to use later. I also added a useEffect that will update the uniforms when the grow value changes.

import { useEffect, useMemo } from "react";
import { Uniforms, useUniforms } from "../utils/uniforms";
import { useConfig } from "../utils/use-config";

// Create some default uniforms
const branchUniforms = {
  progress: 0,
  branchRadius: 0.005,
  branchGrowOffset: 0.1,
};

// Export the typed uniforms
export type BranchUniforms = Uniforms<typeof branchUniforms>;

export const Branches = () => {
  /** ...previous code... */

  // Add a hook to create uniforms
  const [uniforms, setUniforms] = useUniforms(branchUniforms);

  // Get the grow value from the store
  const { grow } = useConfig();

  // The grow value will be used to update the progress uniform
  useEffect(() => {
    setUniforms({
      progress: grow,
    });
  }, [grow]);
};
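
The Uniforms type and the useUniforms hook come from a utils file that is not shown here. Since the leaf code later reads uniforms.progress.value, each entry is presumably a { value } wrapper like the ones three.js expects. Here is a minimal sketch of that shape, assuming nothing about the real helper beyond that; toUniforms and sketchUniforms are hypothetical names:

```typescript
// Hypothetical shape for the Uniforms<T> helper type: every plain value is
// wrapped in the { value } object that three.js shader materials expect.
type Uniforms<T> = { [K in keyof T]: { value: T[K] } };

// Wrap every entry of a plain defaults object in { value }.
function toUniforms<T extends Record<string, unknown>>(defaults: T): Uniforms<T> {
  const result: Record<string, { value: unknown }> = {};
  for (const key of Object.keys(defaults)) {
    result[key] = { value: defaults[key] };
  }
  return result as unknown as Uniforms<T>;
}

const sketchUniforms = toUniforms({ progress: 0, branchRadius: 0.005 });

// Mutating .value in place (instead of replacing the wrapper object) means a
// material that already holds a reference to the wrapper sees the new number.
sketchUniforms.progress.value = 0.5;
```

Sharing the same wrapper objects between all branch materials is what would let a single update reach every branch at once.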

Finally, I looped through the branches array and returned a <Branch /> component. Here's how the final code looks:

import { useLoader } from "@react-three/fiber";
import { GLTFLoader } from "three-stdlib";
import { PlantGLTF } from ".";
import { useEffect, useMemo } from "react";
import { LineSegments } from "three";
import { Uniforms, useUniforms } from "../utils/uniforms";
import { useConfig } from "../utils/use-config";
import { Branch } from "./branch";

const branchUniforms = {
  progress: 0,
  branchRadius: 0.005,
  branchGrowOffset: 0.1,
};

export type BranchUniforms = Uniforms<typeof branchUniforms>;

export const Branches = () => {
  const plantModel = useLoader(
    GLTFLoader,
    "/experiment-shaders-plants-assets/plant.glb"
  ) as unknown as PlantGLTF;

  const branches = useMemo(() => {
    return Object.values(plantModel.nodes).filter(
      (node) => node.name.startsWith("Branch") && node.type === "LineSegments"
    ) as LineSegments[];
  }, [plantModel]);

  const [uniforms, setUniforms] = useUniforms(branchUniforms);

  const { grow } = useConfig();

  useEffect(() => {
    setUniforms({
      progress: grow,
    });
  }, [grow]);

  return (
    <group>
      {branches.map((branch, i) => (
        <Branch uniforms={uniforms} segments={branch} key={i} branchlets={15} />
      ))}
    </group>
  );
};

If we turn on the debug for the branches, we can see the path that they will follow:

The <Branch/> component

The Branch component will do two things. It creates a mesh that grows following the segments path, guided by uniforms.progress, and it also adds smaller, procedural branches known as branchlets.

To kick things off, we'll set up an interface for our component:

export interface BranchProps {
  segments: LineSegments;
  uniforms: BranchUniforms;
  branchlets: number;
}

Now, let's dive into the component. We begin by loading the branch texture with the useTexture hook.

export const Branch = ({ segments, uniforms, branchlets }: BranchProps) => {
  const branchMap = useTexture(
    "/experiment-shaders-plants-assets/branch-texture.jpg",
    (t: Texture) => {
      t.wrapT = RepeatWrapping;
      t.wrapS = MirroredRepeatWrapping;
    }
  );
};

I wanted to generate the mesh procedurally, so I created a getBranchMesh function. It takes the segments path and the uniforms and returns a branchMesh and a branchPath.

export const Branch = ({ segments, uniforms, branchlets }: BranchProps) => {
  const branchMap = useTexture(
    "/experiment-shaders-plants-assets/branch-texture.jpg",
    (t: Texture) => {
      t.wrapT = RepeatWrapping;
      t.wrapS = MirroredRepeatWrapping;
    }
  );

  const { branchMesh, branchPath } = useMemo(() => {
    return getBranchMesh(segments.clone(true), uniforms, branchMap);
  }, [segments, uniforms, branchMap]);

  return <primitive object={branchMesh} />;
};

Let's go just one more step into this rabbit hole and see what the getBranchMesh function does.

getBranchMesh

This function handles the generation of the branch mesh. The tricky part? The mesh has to track the path segments and grow with them. To achieve this, I first needed to create a cylinder geometry that will be used as a base for the mesh.

/** Transform a lineSegment into a branch mesh */
export const getBranchMesh = (
  branch: LineSegments,
  branchUniforms: BranchUniforms,
  texture: Texture
) => {
  const branchResolution = 20;
  // numVertices is the number of points on the branch path (computed later on)
  const normalizedCylinder = new CylinderGeometry(
    1, // radiusTop
    1, // radiusBottom
    1, // height
    branchResolution, // radialSegments
    numVertices * 2 // heightSegments
  );
};

This is the result of the cylinder geometry:

Let's take some example vertex coordinates of that geometry:

The idea is to shift these cylinder vertices along the path, and we'll use a shader to do just that.

Let's take one of the vertices, call it p, and see how it would be transformed:

Next, we gotta figure out where this vertex would end up on the path. Imagine stretching the path into a straight line to make this easier:

Now we can see that the p point would land between the second and third point of the path. Let's call those points i and j:

Knowing which vertices are i and j (the previous and next vertices), we can start transforming the p point. The first step is to flatten all the vertices of the cylinder geometry to the xz plane:

Then, the point p will be rotated to match the direction j-i.

Finally, we can lerp between the i and j points to find the final position of p and translate it:

At first, I only used the raw vertex data in the branch shader, but I found that it was not very efficient. So I decided to pre-calculate some values in JavaScript and send those over to the shader as uniforms.
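
The flatten, rotate and translate steps can be checked on the CPU before moving them into a shader. The following is a minimal sketch in plain TypeScript, without three.js; placeOnSegment, rotateByQuat and the sample values are hypothetical, and the segment rotation is assumed to be a unit quaternion that maps the cylinder's Y axis onto the direction j-i:

```typescript
type Vec3 = { x: number; y: number; z: number };
type Quat = { x: number; y: number; z: number; w: number };

// Rotate v by the unit quaternion q: v' = v + w * t + u x t, with t = 2 (u x v).
function rotateByQuat(q: Quat, v: Vec3): Vec3 {
  const tx = 2 * (q.y * v.z - q.z * v.y);
  const ty = 2 * (q.z * v.x - q.x * v.z);
  const tz = 2 * (q.x * v.y - q.y * v.x);
  return {
    x: v.x + q.w * tx + (q.y * tz - q.z * ty),
    y: v.y + q.w * ty + (q.z * tx - q.x * tz),
    z: v.z + q.w * tz + (q.x * ty - q.y * tx),
  };
}

function placeOnSegment(
  p: Vec3, // vertex of the unit cylinder
  radius: number, // branch radius
  i: Vec3, // previous path point
  j: Vec3, // next path point
  rotation: Quat, // rotates the Y axis onto j - i
  t: number // 0 at i, 1 at j
): Vec3 {
  // 1. flatten to the xz plane and scale to the branch radius
  const flat: Vec3 = { x: p.x * radius, y: 0, z: p.z * radius };
  // 2. rotate the flattened ring so it faces along the segment
  const rotated = rotateByQuat(rotation, flat);
  // 3. lerp between i and j and translate the ring there
  return {
    x: rotated.x + i.x + (j.x - i.x) * t,
    y: rotated.y + i.y + (j.y - i.y) * t,
    z: rotated.z + i.z + (j.z - i.z) * t,
  };
}

// Segment pointing along +X: the Y-to-X rotation is -90 degrees around Z.
const yToX: Quat = { x: 0, y: 0, z: -Math.SQRT1_2, w: Math.SQRT1_2 };
const placed = placeOnSegment(
  { x: 1, y: 0.2, z: 0 },
  0.1,
  { x: 0, y: 0, z: 0 },
  { x: 1, y: 0, z: 0 },
  yToX,
  0.5
);
```

In this example the vertex ends up halfway along the segment, offset by one radius perpendicular to it.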

Growing meshes

To transform the cylinder geometry's vertices efficiently, I created a function that loops through each vertex of the segments path and calculates the following variables:

// position of this point
position: Vector3;
// distance from this point to the previous one
distance: number;
// direction from this point to the next
direction: Vector3;
// rotation from the prev point to this one
rotation: Quaternion;
// rotation from the first point to this one
addedQuaternion: Quaternion;

Here is a visual representation of the variables:

Note: The addedQuaternion is calculated by multiplying all the rotations from the previous points.

I called this interface PathVertex. I also created a PathVertices interface, which wraps the array together with the precomputed totalDistance and numVertices. Both are produced by the verticesFromLineSegment function:

interface PathVertex {
  // position of this point
  position: Vector3;
  // distance from this point to the previous one
  distance: number;
  // direction from this point to the next
  direction: Vector3;
  // rotation from the prev point to this one
  rotation: Quaternion;
  // rotation from the first point to this one
  addedQuaternion: Quaternion;
}

interface PathVertices {
  pathVertices: PathVertex[];
  totalDistance: number;
  numVertices: number;
}

const verticesFromLineSegment = (branch: LineSegments): PathVertices => {...}

To see the complete implementation of this function, check the source code:

/src/Plant/helpers/path-vertex.ts
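
For readers who want the gist without opening the repository, here is a rough sketch of how such values could be computed from a plain list of points. It is not the author's implementation: sketchPathVertices is a hypothetical name, the first rotation is assumed to start from the cylinder's +Y axis, and the last point is assumed to reuse the previous direction.

```typescript
type Vec3 = { x: number; y: number; z: number };
type Quat = { x: number; y: number; z: number; w: number };

interface SketchPathVertex {
  position: Vec3;
  distance: number;
  direction: Vec3;
  rotation: Quat;
  addedQuaternion: Quat;
}

const sub = (a: Vec3, b: Vec3): Vec3 => ({ x: a.x - b.x, y: a.y - b.y, z: a.z - b.z });
const len = (v: Vec3): number => Math.hypot(v.x, v.y, v.z);
const normalize = (v: Vec3): Vec3 => {
  const l = len(v) || 1;
  return { x: v.x / l, y: v.y / l, z: v.z / l };
};

// Shortest-arc unit quaternion rotating unit vector `from` onto unit vector `to`.
function quatFromUnitVectors(from: Vec3, to: Vec3): Quat {
  const d = from.x * to.x + from.y * to.y + from.z * to.z;
  let q: Quat;
  if (d < -0.999999) {
    // opposite vectors: rotate 180 degrees around any perpendicular axis
    q = Math.abs(from.x) > Math.abs(from.z)
      ? { x: -from.y, y: from.x, z: 0, w: 0 }
      : { x: 0, y: -from.z, z: from.y, w: 0 };
  } else {
    q = {
      x: from.y * to.z - from.z * to.y,
      y: from.z * to.x - from.x * to.z,
      z: from.x * to.y - from.y * to.x,
      w: 1 + d,
    };
  }
  const l = Math.hypot(q.x, q.y, q.z, q.w);
  return { x: q.x / l, y: q.y / l, z: q.z / l, w: q.w / l };
}

// Hamilton product a * b (b is applied first, then a).
function multiply(a: Quat, b: Quat): Quat {
  return {
    x: a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y,
    y: a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x,
    z: a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w,
    w: a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z,
  };
}

function sketchPathVertices(points: Vec3[]) {
  const pathVertices: SketchPathVertex[] = [];
  let totalDistance = 0;
  let prevDirection: Vec3 = { x: 0, y: 1, z: 0 }; // assumed starting axis
  let added: Quat = { x: 0, y: 0, z: 0, w: 1 };
  points.forEach((position, i) => {
    const distance = i === 0 ? 0 : len(sub(position, points[i - 1]));
    const direction =
      i < points.length - 1 ? normalize(sub(points[i + 1], position)) : prevDirection;
    const rotation = quatFromUnitVectors(prevDirection, direction);
    added = multiply(rotation, added); // accumulate all rotations so far
    totalDistance += distance;
    pathVertices.push({ position, distance, direction, rotation, addedQuaternion: added });
    prevDirection = direction;
  });
  return { pathVertices, totalDistance, numVertices: points.length };
}

const sketchPath = sketchPathVertices([
  { x: 0, y: 0, z: 0 },
  { x: 0, y: 1, z: 0 },
  { x: 1, y: 1, z: 0 },
]);
```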

Branch material

Here's a quick summary of our progress:

  • We've got the getBranchMesh function that handles creating the cylinder geometry and transforming its vertices.
  • There is a segments path that will be used to generate the branch mesh.
  • We precalculated the pathVertices to transform the vertices of the mesh.

Now we can start creating the branch shader. First, we need to create the ShaderMaterial:

export const getBranchMesh = (
  branch: LineSegments,
  branchUniforms: BranchUniforms,
  texture: Texture
) => {
  const branchPath = verticesFromLineSegment(branch);

  const { pathVertices, totalDistance, numVertices } = branchPath;

  /** This material will transform the cylinder geometry to follow the path */
  const branchMaterial = new ShaderMaterial({
    name: branch.name + "material",
    vertexShader: branchVertexShader,
    fragmentShader: branchFragmentShader,
    glslVersion: GLSL3,
    defines: {
      NUM_VERTICES: numVertices,
    },
    uniforms: {
      map: { value: texture },
      pathVertices: {
        value: pathVertices,
      },
      totalDistance: { value: totalDistance },
      ...branchUniforms,
    },
  });

  const branchResolution = 20;
  const normalizedCylinder = new CylinderGeometry(
    1,
    1,
    1,
    branchResolution,
    numVertices * 2
  );
  const branchMesh = new Mesh(normalizedCylinder, branchMaterial);

  return {
    branchMesh,
    branchPath,
    position: branch.position,
    rotation: branch.rotation,
  };
};

Vertex shader

The key to the mesh transform is in the vertex shader. We'll start by setting up some basic structures:

struct PathVertex {
  vec3 position;
  float distance;
  vec3 direction;
  // quaternion rotation
  vec4 rotation;
  vec4 addedQuaternion;
};

struct PathPos {
  vec3 position;
  vec3 direction;
  vec4 rotation;
};

Now, we need to define the uniforms.

uniform PathVertex pathVertices[NUM_VERTICES];
uniform float totalDistance;
uniform float progress;
uniform float branchRadius;
uniform float branchGrowOffset;

Next, some varyings to pass data between the vertex and fragment shaders:

varying vec2 vUv;
varying vec3 worldPos;
varying vec3 localPos;
varying float targetFactor;
varying float growFactorRaw;
varying float growFactor; // clamped

getPositionOnPath helper

This function calculates the position of a point on the path. It takes a percentage between 0 and 1 (the shader passes the vertex's targetFactor scaled by the progress) and gives back a PathPos with the position, direction, and rotation of the point.

PathPos getPositionOnPath(float percentage) {

  percentage = clamp(percentage, 0.0, 1.0);

  // Calculate the target distance along the path
  float targetDistance = percentage * totalDistance;

  // Find the indices of the two vertices around the target distance
  // iPrev ------> iNext
  int iNext = 1;
  float traveledDistance = pathVertices[1].distance;
  while (traveledDistance < targetDistance) {
    if (iNext == NUM_VERTICES - 1) {
      // reached the end of the path
      break;
    }
    iNext++;
    traveledDistance += pathVertices[iNext].distance;
  }
  int iPrev = max(0, iNext - 1);

  // Get the two adjacent vertices
  vec3 posPrev = pathVertices[iPrev].position;
  vec3 posNext = pathVertices[iNext].position;

  float distancePrevToNext = pathVertices[iNext].distance;

  // Calculate the interpolation factor based on distances
  // 0 ------ tDist ------ offsetDist
  float offsetDist = traveledDistance - distancePrevToNext;
  float tDist = targetDistance - offsetDist;
  float t = tDist / distancePrevToNext;

  // Interpolate the position
  vec3 position = mix(posPrev, posNext, t);

  // Interpolate the direction between the two vertices
  vec3 direction = mix(pathVertices[iPrev].direction, pathVertices[iNext].direction, t);

  // Mix the accumulated quaternions (a linear blend, not renormalized)
  vec4 rotation = mix(pathVertices[iPrev].addedQuaternion, pathVertices[iNext].addedQuaternion, t);

  return PathPos(
    position,
    direction,
    rotation
  );
}
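
The same walk over the path can be mirrored in TypeScript, which makes it easy to test the lookup away from the GPU. This sketch covers only the position, uses the hypothetical name positionOnPath, and assumes at least two path points with non-zero segment lengths:

```typescript
type Vec3 = { x: number; y: number; z: number };

interface PathPoint {
  position: Vec3;
  // distance from this point to the previous one (0 for the first point)
  distance: number;
}

function positionOnPath(path: PathPoint[], totalDistance: number, percentage: number): Vec3 {
  const clamped = Math.min(Math.max(percentage, 0), 1);
  const targetDistance = clamped * totalDistance;

  // Walk forward until the accumulated length reaches the target distance.
  let iNext = 1;
  let traveled = path[1].distance;
  while (traveled < targetDistance && iNext < path.length - 1) {
    iNext++;
    traveled += path[iNext].distance;
  }

  const prev = path[iNext - 1].position;
  const next = path[iNext].position;
  const segmentLength = path[iNext].distance;

  // How far into the current segment the target lies, from 0 to 1.
  const t = (targetDistance - (traveled - segmentLength)) / segmentLength;

  return {
    x: prev.x + (next.x - prev.x) * t,
    y: prev.y + (next.y - prev.y) * t,
    z: prev.z + (next.z - prev.z) * t,
  };
}

// An L-shaped path: 1 unit along X, then 2 units along Y.
const lPath: PathPoint[] = [
  { position: { x: 0, y: 0, z: 0 }, distance: 0 },
  { position: { x: 1, y: 0, z: 0 }, distance: 1 },
  { position: { x: 1, y: 2, z: 0 }, distance: 2 },
];
const halfway = positionOnPath(lPath, 3, 0.5);
```

Halfway along this path is 1.5 units in, which lands a quarter of the way up the second segment.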

Final vertex shader

Here's the full code for the vertex shader.

${pathStructs} // include of path structs

${pathUniforms} // include of uniforms

varying vec2 vUv;
varying vec3 worldPos;
varying vec3 localPos;
varying float targetFactor;
varying float growFactorRaw;
varying float growFactor; // clamped

${rotate} // include of some rotation helpers
${getPositionOnPath} // include of path helper

float getGrowFactor(float p) {
  float totalLength = totalDistance * p;
  float currentLength = localPos.y * totalLength;
  float growEnd = totalLength - branchGrowOffset;

  float growFactor = (growEnd - currentLength) / branchGrowOffset + 1.;
  return growFactor;
}

void main() {
  float clampedProgress = clamp(progress, 0.0, 1.0);
  localPos = position + vec3(0.0, 0.5, 0.0);
  targetFactor = localPos.y;

  // calculate grow factor
  growFactorRaw = getGrowFactor(clampedProgress);
  growFactor = clamp(growFactorRaw, 0.0, 1.0);
  float branchSize = branchRadius * growFactor;

  // move vertices to y = 0
  vec3 targetPos = position * vec3(branchSize, 0.0, branchSize);

  // translate to path
  PathPos pathPosition = getPositionOnPath(targetFactor * clampedProgress);

  // rotate the Y axis to the direction
  targetPos = qtransform(pathPosition.rotation, targetPos);

  // move to pos
  targetPos += pathPosition.position;


  vUv = uv;
  worldPos = (modelMatrix * vec4(targetPos, 1.0)).xyz;
  gl_Position = projectionMatrix * modelViewMatrix * vec4(targetPos, 1.0);
}

Let's unpack what's going on in this shader:

// Get position on path
PathPos pathPosition = getPositionOnPath(targetFactor * clampedProgress);
// Flatten cylinder to xz plane and scale to branchSize
vec3 targetPos = position * vec3(branchSize, 0.0, branchSize);
// rotate the Y axis to the direction of the
// segment where the vertex should land
targetPos = qtransform(pathPosition.rotation, targetPos);
// translate
targetPos += pathPosition.position;
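
The tapered tip comes from getGrowFactor, which scales the radius down to zero over the last branchGrowOffset units of the grown length. The arithmetic can be mirrored in TypeScript to see the numbers; growFactorAt is a hypothetical name and the sample values are made up:

```typescript
// Mirror of the shader's grow factor, already clamped to [0, 1].
// localY is the vertex height on the unit cylinder (0 = base, 1 = tip).
function growFactorAt(
  localY: number,
  progress: number,
  totalDistance: number,
  growOffset: number
): number {
  const totalLength = totalDistance * progress; // length grown so far
  const currentLength = localY * totalLength; // where this vertex sits
  const growEnd = totalLength - growOffset; // where the taper begins
  const raw = (growEnd - currentLength) / growOffset + 1;
  return Math.min(Math.max(raw, 0), 1);
}

// A path of length 2 grown halfway, with a taper of 0.1 units.
const atBase = growFactorAt(0, 0.5, 2, 0.1); // full radius
const midTaper = growFactorAt(0.95, 0.5, 2, 0.1); // about half radius
const atTip = growFactorAt(1, 0.5, 2, 0.1); // about zero radius
```

Everything further than the offset from the tip keeps the full branchRadius, so only the growing end looks pointed.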

The result of this shader is a mesh that will grow along the path:

By setting grow (and therefore the progress uniform) to 0.5, we can see the branch grow to the midpoint of the path:

Branchlets

Branchlets are secondary branches that will grow from the main branch. I didn't want to model each branchlet in Blender, so I decided to generate them procedurally.

To generate the branchlet, we start by generating random points that stretch out from each branch.

Let's create a function called getBranchletVertices that will take a pathVertices array and a t value between 0 and 1 and will return a new set of pathVertices and a position where the branchlet will start.

The t value will indicate where the branchlet will start on the main branch. For example, if t is 0.5, the branchlet will start in the middle of the branch. It is important to know this value for a couple of reasons:

  • We want the branchlet to grow in the same "direction" as the branch to give a more natural result.
  • We need to know the initial position of the branchlet to translate it to the correct position.
  • We need to calculate the correct growFactor: we don't want the branchlet to start growing before the main branch gets to that point.

/** Generates a branchlet that will start at the same direction/position of the branch at t */
export const getBranchletVertices = (pathVertices: PathVertex[], t: number) => {
  const branchletVertices: Vector3[] = [];

  const { direction, position } = getPathVertex(pathVertices, t);

  const currentDirection = direction.clone();
  const currentPosition = new Vector3(0, 0, 0);
  const randomRotation = new Quaternion();

  // first vertex
  branchletVertices.push(currentPosition.clone());

  const randomFactor = 0.1;
  const numVertices = 20;
  const edgeLength = 0.01;

  for (let i = 0; i < numVertices - 1; i++) {
    // rotate direction by a small random amount
    randomRotation.setFromEuler(
      new Euler(
        (Math.random() - 0.5) * 2 * Math.PI * randomFactor,
        (Math.random() - 0.5) * 2 * Math.PI * randomFactor,
        (Math.random() - 0.5) * 2 * Math.PI * randomFactor
      )
    );

    currentDirection.applyQuaternion(randomRotation);

    // smooth Y over time
    const branchProgress = (i + 1) / numVertices;
    const growOffset = 0.7;
    const cap = 0.2;
    const yFactor =
      Math.cos(branchProgress * Math.PI * growOffset) * cap + (1 - cap);
    currentDirection.y = currentDirection.y * yFactor;
    currentDirection.normalize();

    // move vertex
    currentPosition.add(currentDirection.clone().multiplyScalar(edgeLength));

    // add vertex
    branchletVertices.push(currentPosition.clone());
  }

  return {
    pathVertices: vectorsToPathVertices(branchletVertices),
    position,
  };
};
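
The "smooth Y" step is easier to read in isolation. The sketch below reproduces only that damping, in 2D and with the random rotation left out, to show that the direction gradually leans toward the horizontal as the branchlet gets longer; flattenDirection is a hypothetical helper and the starting direction is arbitrary:

```typescript
type Vec2 = { x: number; y: number };

// Same formula as in the loop: 1 at the start, about 0.68 at the end.
function yFactor(branchProgress: number, growOffset = 0.7, cap = 0.2): number {
  return Math.cos(branchProgress * Math.PI * growOffset) * cap + (1 - cap);
}

// Apply the damping once per vertex, renormalizing every time.
function flattenDirection(start: Vec2, numVertices = 20): Vec2 {
  let dir: Vec2 = { ...start };
  for (let i = 0; i < numVertices - 1; i++) {
    const y = dir.y * yFactor((i + 1) / numVertices);
    const length = Math.hypot(dir.x, y) || 1;
    dir = { x: dir.x / length, y: y / length };
  }
  return dir;
}

const flattened = flattenDirection({ x: 0.6, y: 0.8 });
```

Because the factor is applied at every step, the effect compounds: the vertical component shrinks a little more with each vertex while the vector stays unit length.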

Just like we did with the branch mesh, now we need to create a getBranchletMesh function:

export const getBranchletMesh = (
  path: PathVertices,
  t: number,
  uniforms: BranchUniforms,
  texture: Texture
): Mesh => {
  const branchletGeometry = new CylinderGeometry(
    1,
    1,
    1,
    10,
    path.numVertices * 2
  );
  const branchletMaterial = new ShaderMaterial({
    vertexShader: branchletVertexShader,
    fragmentShader: branchletFragmentShader,
    glslVersion: GLSL3,
    defines: {
      NUM_VERTICES: path.numVertices,
    },
    uniforms: {
      ...uniforms,
      map: { value: texture },
      pathVertices: {
        value: path.pathVertices,
      },
      tStart: { value: t + 0.1 },
      tEnd: { value: t + 0.3 },
      totalDistance: { value: path.totalDistance },
    },
  });
  const branchletMesh = new Mesh(branchletGeometry, branchletMaterial);

  return branchletMesh;
};

Branchlet material

The shader for the branchlet material is similar to the branch material. The main difference is that we are using the tStart and tEnd uniforms to indicate where the branchlet will start and end.

${pathStructs}

${pathUniforms}

uniform float tStart;
uniform float tEnd;

varying vec2 vUv;
varying vec3 worldPos;
varying vec3 localPos;
varying float targetFactor;
varying float growFactorRaw;
varying float growFactor; // clamped
varying float branchletProgress;

// add utils
${rotate}
${getPositionOnPath}
${valueRemap}

float getGrowFactor(float bProgress) {
  float branchletGrowOffset = branchGrowOffset * 0.2;
  float totalLength = totalDistance * bProgress;

  float currentLength = localPos.y * totalLength;
  float growEnd = totalLength - branchletGrowOffset;

  float growFactor = (growEnd - currentLength) / branchletGrowOffset + 1.;
  return growFactor;
}

void main() {
  localPos = position + vec3(0.0, 0.5, 0.0);
  targetFactor = localPos.y;

  //translate to path
  branchletProgress = valueRemap(progress, tStart, tEnd, 0.0, 1.0);
  branchletProgress = clamp(branchletProgress, 0.0, 1.0);
  PathPos pathPosition = getPositionOnPath(targetFactor * branchletProgress);

  // calculate grow factor
  growFactorRaw = getGrowFactor(branchletProgress);
  growFactor = clamp(growFactorRaw, branchletProgress > 0.1 ? 0.5 : 0., 1.0);
  float branchSize = branchRadius * 0.5 * growFactor;

  // move vertices to y = 0
  vec3 targetPos = position * vec3(branchSize, 0.0, branchSize);

  // rotate the Y axis to the direction
  targetPos = qtransform(pathPosition.rotation, targetPos);

  // move to pos
  targetPos += pathPosition.position;


  vUv = uv;
  worldPos = (modelMatrix * vec4(targetPos, 1.0)).xyz;
  gl_Position = projectionMatrix * modelViewMatrix * vec4(targetPos, 1.0);
}

The branchlet fragment shader:

out vec4 fragColor;

varying vec3 worldPos;
varying vec2 vUv;
varying float targetFactor;
varying vec3 localPos;
varying float growFactor;
varying float branchletProgress;

uniform float totalDistance;
uniform float progress;
uniform sampler2D map;

void main() {
  vec2 mapUv = vec2(vUv.x * 2.0, localPos.y * branchletProgress * totalDistance * 4.);
  vec3 colorMap = texture2D(map, mapUv).rgb;

  fragColor = vec4(colorMap, 1.0);
}
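
The valueRemap helper is included in the shader but not shown. A linear remap with the signature below is an assumption, although it matches how the helper is called. The sketch shows how a branchlet at t = 0.5 maps the global progress to its own progress:

```typescript
// Assumed behavior of valueRemap: a linear map from [inMin, inMax] to [outMin, outMax].
function valueRemap(
  value: number,
  inMin: number,
  inMax: number,
  outMin: number,
  outMax: number
): number {
  return outMin + ((value - inMin) / (inMax - inMin)) * (outMax - outMin);
}

const clamp01 = (v: number): number => Math.min(Math.max(v, 0), 1);

// Progress of a branchlet that starts at position t on the main branch,
// using the same tStart / tEnd offsets as getBranchletMesh.
function branchletProgressAt(progress: number, t: number): number {
  const tStart = t + 0.1;
  const tEnd = t + 0.3;
  return clamp01(valueRemap(progress, tStart, tEnd, 0, 1));
}

const notStarted = branchletProgressAt(0.5, 0.5); // main branch has just reached t
const halfGrown = branchletProgressAt(0.7, 0.5); // midway between tStart and tEnd
const fullyGrown = branchletProgressAt(0.9, 0.5); // past tEnd
```

Since t is capped at 0.9, the latest branchlet finishes at a progress of 1.2, which may be why the grow control goes up to 1.3 instead of 1.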

Observe how the branchlets start growing after the main branch. This is achieved by using the tStart and tEnd uniforms to shift the grow factor:

Branchlet component

Using the getBranchletMesh function, we can create a <Branchlet /> component that will take the pathVertices and t values and will return a <primitive /> component.

import { useMemo } from "react";
import {
  getBranchletMesh,
  getBranchletVertices,
} from "./helpers/branchlet-utils";
import { PathVertex } from "./helpers/path-vertex";
import { BranchUniforms } from "./branches";
import { Leaf } from "./leaf";
import { BufferGeometry, Line, LineBasicMaterial, Texture } from "three";
import { DebugLine } from "./debug-line";

interface BranchletProps {
  pathVertices: PathVertex[];
  uniforms: BranchUniforms;
  t: number;
  texture: Texture;
}

export const Branchlet = ({
  pathVertices,
  uniforms,
  t,
  texture,
}: BranchletProps) => {
  const { branchletMesh, branchletPath, position } = useMemo(() => {
    const branchletVertices = getBranchletVertices(pathVertices, t);
    const branchletMesh = getBranchletMesh(
      branchletVertices.pathVertices,
      t,
      uniforms,
      texture
    );
    return {
      branchletMesh,
      branchletPath: branchletVertices.pathVertices,
      position: branchletVertices.position,
    };
  }, [t, texture]);

  const branchletLineSegments = useMemo(() => {
    const branchletLineSegments = new Line(
      new BufferGeometry().setFromPoints(
        branchletPath.pathVertices.map((v) => v.position)
      ),

      new LineBasicMaterial({
        color: "blue",
        linewidth: 1,
        depthTest: false,
      })
    );
    return branchletLineSegments;
  }, [branchletPath]);

  return (
    <group position={position}>
      <primitive object={branchletMesh} />
    </group>
  );
};

Now, we can update the branch.tsx file to generate an array of random numbers. Those numbers will indicate where the branchlet will grow from the main branch.

import {
  LineSegments,
  MirroredRepeatWrapping,
  RepeatWrapping,
  Texture,
} from "three";
import { BranchUniforms } from "./branches";
import { useMemo } from "react";
import { getBranchMesh } from "./helpers/get-branch-mesh";
import { Branchlet } from "./branchlet";
import { useTexture } from "@react-three/drei";
import { DebugLine } from "./debug-line";

export interface BranchProps {
  segments: LineSegments;
  uniforms: BranchUniforms;
  branchlets: number;
}

export const Branch = ({ segments, uniforms, branchlets }: BranchProps) => {
  const branchMap = useTexture(
    "/experiment-shaders-plants-assets/branch-texture.jpg",
    (t: Texture) => {
      t.wrapT = RepeatWrapping;
      t.wrapS = MirroredRepeatWrapping;
    }
  );

  const { branchMesh, branchPath, position, rotation } = useMemo(() => {
    return getBranchMesh(segments.clone(true), uniforms, branchMap);
  }, [segments, uniforms, branchMap]);

  /* Generate the random points */
  const branchletsArr = useMemo(() => {
    const ts = Array.from(Array(branchlets).keys()).map(() =>
      Math.min(Math.random(), 0.9)
    );
    // add a final branchlet at the end
    ts.push(0.9);
    return ts;
  }, [branchlets]);

  return (
    <group position={position} rotation={rotation}>
      <primitive object={branchMesh} />
      {/* Map over the points*/}
      {branchletsArr.map((t, i) => (
        <Branchlet
          pathVertices={branchPath.pathVertices}
          uniforms={uniforms}
          texture={branchMap}
          t={t}
          key={i}
        />
      ))}
    </group>
  );
};

Branchlets result:

Leaves

Our plant model has a single leaf geometry that we will use to create all the leaves. I created a leaf node with useMemo.

export interface LeafProps {
  branchletPath: PathVertices;
  uniforms: BranchUniforms;
  /** Where in the main branch the branchlet starts */
  t: number;
}

export const Leaf = ({ branchletPath, uniforms, t }: LeafProps) => {
  const { debugLeaves, renderLeaves } = useConfig();

  /* Load model */
  const plantModel = useLoader(
    GLTFLoader,
    "/experiment-shaders-plants-assets/plant.glb"
  ) as unknown as PlantGLTF;

  const [leafUniforms, setLeafUniforms] = useUniforms({
    leafProgress: 0,
  });

  const modelNode = useMemo(() => {
    const leaf = plantModel.nodes.leaf.clone();
    const leafMaterial = leaf.material as MeshStandardMaterial;
    const leafTexture = leafMaterial.map.clone();
    leafTexture.colorSpace = "srgb-linear";

    /* Create material */
    leaf.material = new ShaderMaterial({
      side: DoubleSide,
      transparent: true,
      vertexShader: leafVertexShader,
      fragmentShader: leafFragmentShader,
      glslVersion: GLSL3,
      uniforms: {
        ...leafUniforms,
        leafTexture: {
          value: leafTexture,
        },
      },
    });

    return leaf;
  }, [plantModel, branchletPath, uniforms, t]);

  const helper = useMemo(() => {
    return new AxesHelper(0.1);
  }, []);
};

Then, inside a useEffect, we calculate the position and scale of the leaf based on the growFactor.

export interface LeafProps {
  branchletPath: PathVertices;
  uniforms: BranchUniforms;
  /** Where in the main branch the branchlet starts */
  t: number;
}

export const Leaf = ({ branchletPath, uniforms, t }: LeafProps) => {
  const { debugLeaves, renderLeaves } = useConfig();

  const plantModel = useLoader(
    GLTFLoader,
    "/experiment-shaders-plants-assets/plant.glb"
  ) as unknown as PlantGLTF;

  const [leafUniforms, setLeafUniforms] = useUniforms({
    leafProgress: 0,
  });

  const modelNode = useMemo(() => {
    const leaf = plantModel.nodes.leaf.clone();
    const leafMaterial = leaf.material as MeshStandardMaterial;
    const leafTexture = leafMaterial.map.clone();
    leafTexture.colorSpace = "srgb-linear";

    leaf.material = new ShaderMaterial({
      side: DoubleSide,
      transparent: true,
      vertexShader: leafVertexShader,
      fragmentShader: leafFragmentShader,
      glslVersion: GLSL3,
      uniforms: {
        ...leafUniforms,
        leafTexture: {
          value: leafTexture,
        },
      },
    });

    return leaf;
  }, [plantModel, branchletPath, uniforms, t]);

  const helper = useMemo(() => {
    return new AxesHelper(0.1);
  }, []);

  useEffect(() => {
    if (!modelNode) return;

    const abortController = new AbortController();
    const signal = abortController.signal;
    const isCanceled = () => signal.aborted;

    const curve = new CatmullRomCurve3(
      branchletPath.pathVertices.map((v) => v.position)
    );

    let prevProgress: number | null = null;

    const raf = () => {
      if (isCanceled()) return;
      const currentProgress = uniforms.progress.value;
      if (prevProgress === currentProgress) {
        requestAnimationFrame(raf);
        return;
      }

      const tStart = t + 0.1;
      const tEnd = t + 0.3;
      let branchletProgress = valueRemap(currentProgress, tStart, tEnd, 0, 1);
      branchletProgress = clamp(branchletProgress, 0, 1);

      // move model along branchlet path
      const point = curve.getPointAt(branchletProgress);
      const startingDirection = curve.getTangentAt(0);
      const endDirection = curve.getTangentAt(1);
      const branchDirection = startingDirection.lerp(
        endDirection,
        branchletProgress
      );
      const branchCurrentDirection = curve.getTangentAt(branchletProgress);
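      // renormalize: lerping unit vectors shortens them, and acos(dot) below expects unit length
      branchDirection.lerp(branchCurrentDirection, 0.2).normalize();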
      modelNode.position.copy(point);
      helper.position.copy(point);

      // rotate model to face branch direction
      const leafDirection = new Vector3(1, 0, 0).normalize();
      const axis = new Vector3()
        .crossVectors(leafDirection, branchDirection)
        .normalize();
      const angle = Math.acos(leafDirection.dot(branchDirection));
      modelNode.quaternion.setFromAxisAngle(axis, angle);
      helper.quaternion.setFromAxisAngle(axis, angle);

      // scale
      modelNode.scale.setScalar(branchletProgress * 2);

      // update progress
      prevProgress = currentProgress;
      setLeafUniforms({ leafProgress: branchletProgress });

      requestAnimationFrame(raf);
    };
    requestAnimationFrame(raf);

    return () => {
      abortController.abort();
    };
  }, [uniforms, branchletPath, t, modelNode]);

  return (
    <group>
      {debugLeaves && <primitive object={helper} />}
      {renderLeaves && <primitive object={modelNode} />}
    </group>
  );
};
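
The orientation step in the effect builds a quaternion from an axis (the cross product of the two directions) and an angle (the arccosine of their dot product). The same construction can be written in plain TypeScript without three.js; quatBetween and rotateByQuat are hypothetical names, and this sketch normalizes its inputs and falls back to the identity when the vectors are parallel:

```typescript
type Vec3 = { x: number; y: number; z: number };
type Quat = { x: number; y: number; z: number; w: number };

const normalize = (v: Vec3): Vec3 => {
  const l = Math.hypot(v.x, v.y, v.z) || 1;
  return { x: v.x / l, y: v.y / l, z: v.z / l };
};

// Axis-angle quaternion that rotates `from` onto `to`.
function quatBetween(from: Vec3, to: Vec3): Quat {
  const f = normalize(from);
  const t = normalize(to);
  const dot = Math.min(Math.max(f.x * t.x + f.y * t.y + f.z * t.z, -1), 1);
  const axis: Vec3 = {
    x: f.y * t.z - f.z * t.y,
    y: f.z * t.x - f.x * t.z,
    z: f.x * t.y - f.y * t.x,
  };
  const axisLength = Math.hypot(axis.x, axis.y, axis.z);
  // Parallel vectors have no well-defined axis; this sketch returns identity
  // and does not handle the exactly opposite case.
  if (axisLength < 1e-8) return { x: 0, y: 0, z: 0, w: 1 };
  const angle = Math.acos(dot);
  const s = Math.sin(angle / 2) / axisLength;
  return { x: axis.x * s, y: axis.y * s, z: axis.z * s, w: Math.cos(angle / 2) };
}

// Rotate v by the unit quaternion q.
function rotateByQuat(q: Quat, v: Vec3): Vec3 {
  const tx = 2 * (q.y * v.z - q.z * v.y);
  const ty = 2 * (q.z * v.x - q.x * v.z);
  const tz = 2 * (q.x * v.y - q.y * v.x);
  return {
    x: v.x + q.w * tx + (q.y * tz - q.z * ty),
    y: v.y + q.w * ty + (q.z * tx - q.x * tz),
    z: v.z + q.w * tz + (q.x * ty - q.y * tx),
  };
}

// The leaf model points along +X; aim it at a branch direction along +Y.
const leafAxis: Vec3 = { x: 1, y: 0, z: 0 };
const aimed = rotateByQuat(quatBetween(leafAxis, { x: 0, y: 2, z: 0 }), leafAxis);
```

Normalizing both inputs matters here: with a shortened direction vector the dot product underestimates the cosine, and the leaf would be rotated by more than the real angle.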

By turning on debug, we can observe how the helper follows the end of the branchlet:

Leaf material

The shaders for the leaf are very simple, because all the position and rotation calculations are done in the <Leaf /> component.

varying vec2 vUv;

void main() {
  vUv = uv;

  vec3 worldPos = (modelMatrix * vec4(position, 1.0)).xyz;
  gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);

}

Here I am using the leafTexture and leafProgress (remapped growFactor) to color the leaf.

out vec4 fragColor;

uniform sampler2D leafTexture;
uniform float leafProgress;

varying vec2 vUv;

void main() {
  vec4 colorMap = texture2D(leafTexture, vUv).rgba;
  if (colorMap.a < 0.6) discard;

  vec3 result = colorMap.rgb;
  vec3 green = vec3(0.2, 0.5, 0.2);

  result = mix(green, result, leafProgress);

  fragColor = vec4(result, colorMap.a);
}

Adding the leaves to the scene

Finally, we can add the <Leaf /> into the <Branchlet /> component:

export const Branchlet = ({
  pathVertices,
  uniforms,
  t,
  texture,
}: BranchletProps) => {
  /* ...branchlet source */

  return (
    <group position={position}>
      <primitive object={branchletMesh} />
      {/* Add the leaf component */}
      <Leaf t={t} branchletPath={branchletPath} uniforms={uniforms} />
    </group>
  );
};
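
The snippets so far drive grow from the leva slider, while the objectives also mention reacting to scroll. One possible way to connect the two is to map the page's scroll position onto the same 0 to 1.3 range and write it to the store. This is only a sketch under that assumption; scrollToGrow is a hypothetical helper, and the listener in the comment is untested wiring:

```typescript
// Map a scroll position to a grow value between 0 and maxGrow.
function scrollToGrow(
  scrollTop: number, // current scroll offset in px
  scrollHeight: number, // full height of the scrollable content in px
  viewportHeight: number, // visible height in px
  maxGrow = 1.3 // same upper bound as the leva control
): number {
  const scrollable = scrollHeight - viewportHeight;
  if (scrollable <= 0) return 0; // nothing to scroll, keep the plant at its start
  const ratio = Math.min(Math.max(scrollTop / scrollable, 0), 1);
  return ratio * maxGrow;
}

// In the browser this could be wired up roughly like this (not executed here):
//
// window.addEventListener("scroll", () => {
//   const el = document.documentElement;
//   useConfigStore.setState({
//     grow: scrollToGrow(el.scrollTop, el.scrollHeight, el.clientHeight),
//   });
// });

const growAtHalfScroll = scrollToGrow(500, 2000, 1000);
```

Keeping the mapping as a pure function makes it easy to test and to swap the scroll source later.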

Final result

You've reached the finish line – thanks for sticking around! If you liked this experiment, feel free to ⭐ the repository.

Take a look at the live demo and explore the source code.