How to integrate structured decoding? #23

Closed
jimwhite opened this issue Mar 3, 2025 · 1 comment

jimwhite commented Mar 3, 2025

Hi Will!

Fabulous work! I've been wanting to train LLMs on code execution for a long time, but these ML framework innards have a steep learning curve.

In addition to format rewards, I'd also like to use structured decoding so that only correctly formatted output (including syntactically correct code) is generated. I know vLLM and Ollama can do it, but I have no idea how to integrate that with the verifiers GRPOTrainer.

https://www.bentoml.com/blog/structured-decoding-in-vllm-a-gentle-introduction

Will that work during training, or is it something that has only been implemented for inference time?
Any pointers or suggestions on how to go about this?

Thanks!
Jim

willccbb (Owner) commented Mar 3, 2025

Yes, this is supported! In any MultiStepEnv or SimpleEnv, you can pass vLLM sampling-parameter overrides via the sampling_args dict argument; in ToolEnv and CodeEnv this is used to add stop strings like "</answer>", but you can also use it to add guided decoding / structured generation.

Example:

from pydantic import BaseModel
import verifiers as vf
from verifiers.tools import calculator
from verifiers.prompts import CALCULATOR_FEW_SHOT
from vllm.sampling_params import GuidedDecodingParams
...

# Constrain every rollout to match a regex (fill in your own pattern).
regex_pattern = r'...'
regex_gd = GuidedDecodingParams(regex=regex_pattern)

vf_env = vf.ToolEnv(
    dataset="gsm8k",
    tools=[calculator],
    sampling_args={"guided_decoding": regex_gd},
)

This should be equivalent to setting the guided_decoding parameter inside SamplingParams for offline inference; see: https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/structured_outputs.py
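For reference, a minimal sketch of that offline equivalent (the model name and regex here are placeholders, not taken from the example above):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Placeholder pattern and model; swap in your own.
guided = GuidedDecodingParams(regex=r"\d+")
params = SamplingParams(guided_decoding=guided, max_tokens=64)

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
out = llm.generate(["2 + 2 = "], params)
print(out[0].outputs[0].text)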

Guided decoding is only applied at inference time, but since all your rollouts will then follow the pattern, it is effectively enforced at training time as well.

I will caution that writing correct schemas (especially for things like "correctly formatted code") is pretty tricky, especially in multi-turn settings where we want to allow different formats for intermediate vs. final turns (code vs. answer). I have found using stop tokens to be a fairly decent compromise (these were added fairly recently; check whether they're set in the environments in your current local version, and it may be good to pull again). You can also experiment with adding reward functions which check whether a schema has been followed, at varying levels of strictness. Particularly for small models which don't follow instructions as well initially, I have found that using several increasingly-strict format reward functions can create a "trail of breadcrumbs" that lets the model learn to follow more complex formats incrementally (see e.g. GRPO Llama-1B for some examples of this; you can append any functions you want to the rubric). A minimal sketch of this idea is below.
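A rough illustration of the breadcrumbs idea; the function names, reward values, and the rubric hookup at the end are assumptions for illustration, not verifiers' exact API:

import re

# Hypothetical reward functions at increasing strictness.
def loose_format_reward(completion: str, **kwargs) -> float:
    # Partial credit for producing <answer> tags at all.
    return 0.2 if "<answer>" in completion and "</answer>" in completion else 0.0

def strict_format_reward(completion: str, **kwargs) -> float:
    # Full credit only for a complete <reasoning>/<answer> layout.
    pattern = r"^<reasoning>.*?</reasoning>\s*<answer>.*?</answer>$"
    return 0.5 if re.match(pattern, completion.strip(), re.DOTALL) else 0.0

# Append to whatever reward list the trainer consumes (accessor assumed).
reward_funcs = vf_env.get_rubric() + [loose_format_reward, strict_format_reward]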

All of the examples/envs so far operate with XML structures, but it would be possible to add support for using JSON instead (a bit low on my priority list, though). That might make complex structured decoding a bit more practical.
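If you did go the JSON route, a sketch of how that could plug into vLLM's JSON-schema guided decoding (the ToolCall model here is made up for illustration):

from pydantic import BaseModel
from vllm.sampling_params import GuidedDecodingParams

# Hypothetical output schema; any pydantic model's JSON schema works here.
class ToolCall(BaseModel):
    tool: str
    args: dict

json_gd = GuidedDecodingParams(json=ToolCall.model_json_schema())
# Then pass it the same way as the regex example:
# sampling_args={"guided_decoding": json_gd}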

willccbb closed this as completed Mar 3, 2025