Semantically-Aware Constrained Decoding for Code Generation

In my previous post, I shared Microsoft's AICI, a very clever way of controlling the LLM inference process through portable, efficient, sandboxed WASM-based virtual machines.

Through a powerful abstraction like this, we could bring different stages of the generative pipeline of a Transformer-like architecture all the way up to the application layer.

Now that we are no longer limited by the static nature of prompts and hyperparameter configuration, we, as consumers, can apply advanced techniques for controlling the outputs of LLMs.

Constrained Decoding

My focus for this post is constrained decoding, specifically for text-based LLMs, but the technique applies to any sequence-to-sequence generative model and may extend beyond that.

In simple terms, constrained decoding consists of reducing the set of possible next-token predictions to guide the model toward a more desirable outcome. This is usually done by providing logit biases for specific tokens at each decoding step.
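
As a minimal sketch of that mechanism (the names here are illustrative, not any particular runtime's API), a hard constraint amounts to masking the logit of every disallowed token before sampling:

// Push every token the active constraint disallows to -Infinity so the
// sampler can never pick it. `logits` is the raw next-token distribution;
// `allowedTokenIds` comes from whatever constraint applies at this step.
function applyConstraint(logits, allowedTokenIds) {
    const allowed = new Set(allowedTokenIds)
    return logits.map((logit, tokenId) => (allowed.has(tokenId) ? logit : -Infinity))
}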

An example of this is OpenAI's "JSON Mode", which, among other techniques, applies a grammar-based constraint in the decoding phase of GPT to reduce occurrences of invalid JSON output. This methodology is likely also applied to OpenAI's function calling feature, to ensure function calls are well-formed.
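
As consumers, we already get a small window into the same lever through the logit_bias parameter of OpenAI's API, which shifts the logits of specific tokens from the outside. A quick sketch (token ids vary by tokenizer, and "42" here is just an illustrative id):

const response = await fetch("https://api.openai.com/v1/chat/completions", {
    method: "POST",
    headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
    },
    body: JSON.stringify({
        model: "gpt-4",
        messages: [{ role: "user", content: "Pick a number." }],
        // Biases range from -100 to 100; -100 effectively bans the token.
        logit_bias: { "42": -100 },
    }),
})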

To illustrate the ideas in this post, I decided to use JavaScript, so I could express them in a language that is accessible to a wider audience, but this could be implemented in any language.

In fact, the AICI project has a JavaScript interface, JsCtrl, which you can use to write these controllers.

This is what an AICI controller looks like:

async function main() {
    await $`The ultimate answer to life, the universe and everything is `
    await gen({ regex: /\d\d/ })
}

start(main)

A likely output for this inference would be:

The ultimate answer to life, the universe and everything is 42

For general-purpose code generation, however, I would lean towards a more precise approach to constrained decoding, one that explicitly represents the abstract syntax tree (AST). This improves the readability and maintainability of the generated code, and also allows for more advanced features like context-awareness and type safety.

In general, the idea is to use functions with tagged template literals to express the AST of the code that we want to generate and the constraints that we want to apply to the LLM.

For example, a function called membersOf would be used to express a MemberExpression tree in JavaScript, and a function called property would be used to express a Property tree that is a child of a MemberExpression tree.

/**
 * 
 * @example Syntax Tree
 * MemberExpression(
 *   object: Identifier('user'),
 *   property: Identifier('name')
 * )
 * @example Outputs
 * user
 * user.name
 * user.email
 */
membersOf`user`(
    property`name`,
    property`email`
)
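
To make this concrete, here is one possible shape for these builders. It's a sketch that assumes constraints are compiled down to plain AST-node descriptors before being handed to the decoder; every name and field below is speculative:

function property(strings) {
    // property`name` -> a Property node whose key is the identifier "name"
    return { type: "Property", key: { type: "Identifier", name: strings[0] } }
}

function membersOf(strings) {
    // membersOf`user` -> a curried builder that collects the allowed properties
    const object = strings[0]
    return (...properties) => ({
        type: "MemberExpression",
        object,
        properties, // the completions the decoder may choose between
    })
}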

Here's a more practical example of how this could be used to express a constraint to match one of five possible JavaScript member expressions:

import { oneOf, membersOf, property } from "?"

async function main() {
    await oneOf(
        /**
         * @example Output
         * meaningOfLife.answer
         * meaningOfLife.question
         */
        membersOf`meaningOfLife`(
            property`answer`,
            property`question`
        ),
        /**
         * @example Output
         * meaningOfLife.author.name
         * meaningOfLife.author.born
         * meaningOfLife.author.died
         */
        membersOf`meaningOfLife.author`(
            property`name`,
            property`born`,
            property`died`
        )
    )
}

start(main)

The content in the tagged template literals would work the same way as AICI's $ function, but with the added benefit of being able to mix the AST with the grammar constraints.

We could also express other syntactic features like function calls, computed properties, or even variable references.

import { oneOf, membersOf, callOf, property, computedProperty } from "?"

async function main() {
    await oneOf(
        membersOf`meaningOfLife`(
            property`answer`,
            property`question`
        ),
        membersOf`meaningOfLife.author`(
            property`name`,
            property`born`,
            property`died`
        ),
        membersOf`inventory.fruit`(
            computedProperty`0`,
            computedProperty`1`,
            computedProperty`2`
        ),
        membersOf`userMap`(
            computedProperty`userId`,
        ),
        /**
         * @example
         * window.localStorage.getItem(userMap[userId])
         */
        callOf`window.localStorage.getItem`(
        membersOf`userMap`(
                computedProperty`userId`
            )
        )
    )
}

start(main)

At a surface level, these could be expressed as a grammar; in fact, here's how it would look in EBNF:

RULE_0_A := "name" | "born" | "died"
RULE_0   := "meaningOfLife" [ "." ( "answer" | "question" | "author" [ "." RULE_0_A ] ) ]
RULE_1_A := "0" | "1" | "2"
RULE_1   := "inventory" [ "." "fruit" [ "[" RULE_1_A "]" ] ]
RULE_2   := "userMap" [ "[" "userId" "]" ]
RULE_3   := "window.localStorage.getItem" "(" "userMap" "[" "userId" "]" ")"
OUTPUT   := RULE_0 | RULE_1 | RULE_2 | RULE_3
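
A compiler from the constraint tree down to such a grammar falls out almost directly from the node descriptors sketched earlier. Here's a rough, hypothetical version that handles MemberExpression nodes only:

function toGrammarRule(node) {
    if (node.type !== "MemberExpression") {
        throw new Error(`unsupported node type: ${node.type}`)
    }
    // e.g. "answer" | "question"
    const alternatives = node.properties
        .map((p) => JSON.stringify(p.key.name))
        .join(" | ")
    // Mirrors the optional-property convention used above.
    return `"${node.object}" [ "." ( ${alternatives} ) ]`
}

// toGrammarRule(membersOf`meaningOfLife`(property`answer`, property`question`))
// => '"meaningOfLife" [ "." ( "answer" | "question" ) ]'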

However, expressing the AST in our decoding process has the added benefit of being able to express constraints that are also semantically significant.

Semantically-Aware Constrained Decoding

So far, we've only been talking about static semantic constraints. But what if we could go further? Could we handle dynamic semantic constraints?

Consider this example, which constrains the LLM to pick between two different API calls, fetchUsersByEmail and postUser. It illustrates how str() could be used to express the AST of the code we want to generate while also leaving blanks for the LLM to fill in.

import { oneOf, callOf, obj, property, str } from "?"

const name = str()
const email = str()

/**
 * @example Output
 * fetch("https://api.example.com/[email protected]")
 */
const fetchUsersByEmail = callOf`fetch`(
    str`https://api.example.com/users?email=${email}`
)

/**
 * @example Output
 * fetch("https://api.example.com/users", {
 *   method: "POST",
 *   body: {
 *     name: "John Doe",
 *     email: "john@example.com"
 *   }
 * })
 */
const postUser = callOf`fetch`(
    str`https://api.example.com/users`,
    obj(
        property`method`("POST"),
        property`body`(
            obj(
                property`name`(name),
                property`email`(email)
            )
        )
    )
)

async function main() {
    await oneOf(
        fetchUsersByEmail,
        postUser
    )
}

start(main)

But generating literals is... well, too literal (and let's be honest, table stakes).

Of course the model can generate values for name and email, but in the real world we may want to use values that are already present in the application's context, like variables and other expressions.

What if we used a special contextOf function to express the context of the application in a way that can be referenced by the LLM in a semantically-accurate way?

import { contextOf, membersOf, callOf, obj, property, str, num, bool } from "?"

/**
 * @example Output
 * session.user.name
 * session.user.email
 * api.url
 */
const context = contextOf(
    membersOf`session.user`(
        property`name`(str),
        property`email`(str),
        property`age`(num)
    ),
    membersOf`api`(
        property`url`(str),
        property`public`(bool)
    )
)

const postUser = callOf`fetch`(
    str(),
    obj(
        property`method`("POST"),
        property`body`(
            obj(
                property`name`(str()),
                property`email`(str())
            )
        )
    )
)

async function main() {
    await context;
    await postUser;
}

start(main)

The contextOf function does not generate a context; it references one and makes it available to future generative steps.

We would implement our str handler to prompt the model to generate a string literal or reference a string from the context.
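
Sketched out (with speculative names, and assuming a gen primitive similar to AICI's, extended with a hypothetical oneOf option that mixes regexes and fixed strings), that handler might look something like this:

async function str() {
    // Dotted paths in the context whose declared type is string,
    // e.g. ["session.user.name", "session.user.email"]
    const references = currentContext().pathsOfType("string")

    // Let the model either write a fresh quoted literal or emit one of
    // the known references, but nothing else.
    return gen({ oneOf: [/"[^"]*"/, ...references] })
}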

Speculative Implementation Details

Behind the scenes, a type mapping could be generated from the contextOf call, and the LLM would be guided to generate values that are semantically accurate and type safe.

string:
    - session.user.name
    - session.user.email
number:
    - session.user.age
boolean:
    - api.public
# ...
SomeType:
    - some.path.to.value
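
One way this mapping could be derived, assuming each contextOf declaration carries its dotted path and a type tag (both hypothetical internals):

function typeMapping(declarations) {
    // Bucket every declared context path under its type, producing the
    // lookup table above: { string: [...], number: [...], ... }
    const byType = {}
    for (const { path, type } of declarations) {
        if (!byType[type]) byType[type] = []
        byType[type].push(path)
    }
    return byType
}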

Practically speaking, the real context values would be provided by the application at runtime; they wouldn't need to be available at inference time, since the type mapping alone is enough to guide generation.

Type-Safe Generative Variables

If we want to allow the LLM to output variable declarations, we could define a variable function.

Then, when generating arguments for function calls, we would constrain the LLM to use a variable that is already in the context or a literal value that satisfies the type of the argument.
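
Analogous to the str handler sketched earlier, a num handler could offer the model both numeric literals and any in-scope number-typed variable. A speculative sketch, reusing the same hypothetical gen and lookup primitives:

async function num() {
    // Variables declared so far whose type is number,
    // e.g. ["apples", "oranges", "applePrice"]
    const variables = currentScope().variablesOfType("number")

    // Either a numeric literal or one of the known variables.
    return gen({ oneOf: [/\d+(\.\d+)?/, ...variables] })
}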

Here's a simple arithmetic calculator, built by constraining the model to only output variable declarations of explicitly number type and invocations of exactly the add, sub, mul, and div functions.

import { manyOf, callOf, returnOf, variable, id, num } from "?"

async function main() {
    /**
     * @example Prompt
     * I bought 3 apples (3.99 ea) and 4 oranges (4.39 ea).
     * 
     * How much did I spend?
     * @example Output
     * let apples = 3
     * let oranges = 4
     * let applePrice = 3.99
     * let orangePrice = 4.39
     * let appleTotal = mul(apples, applePrice)
     * let orangeTotal = mul(oranges, orangePrice)
     * let total = add(appleTotal, orangeTotal)
     * return total
     */
    await manyOf(
        variable`${id()}`(num()),
        variable`${id()}`(callOf`add`(num(), num())),
        variable`${id()}`(callOf`sub`(num(), num())),
        variable`${id()}`(callOf`mul`(num(), num())),
        variable`${id()}`(callOf`div`(num(), num())),
    )
    await returnOf(num())
}

start(main)

The LLM would still need to generate the variable declarations and function invocations, but it would be guided to generate semantically-accurate and type-safe code.

Bonus: Refactoring

We could express the same code in a more readable, less verbose way by abstracting the variable function.

import { manyOf, callOf, returnOf, variable, id, num } from "?"

const varFor = variable`${id()}`

async function main() {
    await manyOf(
        varFor(num()), // e.g. let apples = 3
        varFor(callOf`add`(num(), num())), // e.g. let totalFruit = add(apples, oranges)
        varFor(callOf`sub`(num(), num())), // e.g. let oneMinusTwo = sub(1, 2)
        varFor(callOf`mul`(num(), num())), // e.g. let applesTotal = mul(apples, applePrice)
        varFor(callOf`div`(num(), num())), // e.g. let orangeToAppleRatio = div(oranges, apples)
    )
    await returnOf(num()) // e.g. return orangeTotal
}

start(main)

Conclusion

I know this is purely a theoretical idea at this point, but I think it's a very interesting one that would allow us to get semantically accurate and type-safe code generation from LLMs.

This works in theory, but there are many practical challenges to overcome. For example, how would we handle the context at runtime? How would we handle the type mapping? How would we know when to pause and resume the inference process to provide the context?

I deeply believe these are all solvable problems, and I'm excited to see where this goes.