Skip to content

Instantly share code, notes, and snippets.

@jcdiv47
Last active June 13, 2024 02:00
Show Gist options
  • Save jcdiv47/d889622f7e5d4c8197ac2d54b296d5e9 to your computer and use it in GitHub Desktop.
Save jcdiv47/d889622f7e5d4c8197ac2d54b296d5e9 to your computer and use it in GitHub Desktop.
Stopping criteria for a text-to-SQL task with the Transformers text-generation pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, \
GenerationConfig, BitsAndBytesConfig, StoppingCriteria, \
TextStreamer, pipeline
import torch
class GenerateSqlStoppingCriteria(StoppingCriteria):
    """Stop generation once the closing code fence ("```" + "\n") is emitted.

    With the Baichuan2 tokenizer the fence encodes as token 84 ("```")
    followed by token 5 ("\n"); generation halts as soon as the sequence
    ends with those two ids.  The stop ids are parameterized so the same
    criteria can be reused with a different tokenizer or stop sequence.
    """

    def __init__(self, stop_token_ids=(84, 5)):
        """stop_token_ids: token ids that must appear, in order, at the
        end of the generated sequence.  Defaults match the Baichuan2
        tokenizer (``` -> 84, \n -> 5), preserving the original behavior."""
        self.stop_token_ids = tuple(stop_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the generated ids end with the stop sequence."""
        seq = input_ids[0]
        n = len(self.stop_token_ids)
        if len(seq) < n:
            return False
        # Compare the last n generated ids against the stop sequence.
        return all(
            seq[-n + i] == tok for i, tok in enumerate(self.stop_token_ids)
        )

    # The text-generation pipeline expects a StoppingCriteriaList.
    # Implementing __len__ and __iter__ lets a single criteria instance
    # quack like that list, so it can be passed directly.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self
# --- Tokenizer / model / pipeline setup --------------------------------------
model_id = "baichuan-inc/Baichuan2-13B-chat"

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=False,          # Baichuan2 ships a slow (sentencepiece) tokenizer
    trust_remote_code=True,  # custom tokenizer code lives in the model repo
    revision="v2.0",
)

# 4-bit quantization with bfloat16 compute so the 13B model fits in
# consumer-GPU memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let accelerate place layers across available devices
    quantization_config=quantization_config,
    trust_remote_code=True,
)
# Load the generation defaults published with this model revision.
model.generation_config = GenerationConfig.from_pretrained(model_id, revision="v2.0")

# Stream decoded tokens to stdout as they are generated; skip_prompt avoids
# echoing the input prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)

# Named `pipe` rather than `pipeline` so the imported `pipeline` factory
# function is not shadowed by its own return value.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    revision="v2.0",
    do_sample=False,        # greedy decoding: deterministic SQL output
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    # GenerateSqlStoppingCriteria mimics a StoppingCriteriaList, so it can be
    # passed directly; it halts generation at the closing "```\n" fence.
    stopping_criteria=GenerateSqlStoppingCriteria(),
    streamer=streamer,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment