ChatGPT for Data Augmentation

Mitigating Class Imbalance through Generative Summarization

Aaron Briel
GoPenAI

Image by author

In a previous article, I discussed the application of abstractive summarization for text-based data augmentation.

Imbalanced class distribution remains a classic ML problem that can result in subpar classification models. Undersampling and oversampling are two methods of addressing this issue. Techniques such as SMOTE and MLSMOTE have been proposed for oversampling, but the high-dimensional nature of the numerical vectors created from text makes other data augmentation approaches preferable.
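
To make the oversampling target concrete, here is a minimal sketch of how one might compute the number of extra samples each class needs in order to match the largest class. The function name samples_needed and the toy labels are illustrative assumptions, not part of any library:

```python
from collections import Counter
from typing import Dict, List


def samples_needed(labels: List[str]) -> Dict[str, int]:
    """Return how many extra samples each class needs to match the largest class."""
    counts = Counter(labels)
    if not counts:
        return {}
    target = max(counts.values())
    return {label: target - count for label, count in counts.items()}


# Toy, made-up intent labels for illustration only
labels = ["greeting"] * 5 + ["pricing"] * 2 + ["schedule_viewing"] * 1
print(samples_needed(labels))
```

Classes with a deficit of zero are already at the target size; the rest are candidates for augmentation.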

This led to my exploration of abstractive and extractive summarization as possible alternatives. Extractive summarization was ruled out because it returns existing sentences, which deduplication during preprocessing would eliminate. Abstractive summarization seemed particularly appealing because of its ability to generate realistic sentences of text. However, its intermittent lack of novelty in writing style led me to explore other solutions.

Prompt Engineering with ChatGPT

After experimenting with basic prompts in ChatGPT, I became intrigued by the possibility of using prompt engineering as a means of augmenting data. Here is an example, based on possible user responses in conversations with real estate agents:

In the case above, the following prompt may suffice for replication:

prompt="Create SUMMARY_COUNT unique, informally written sentences \
similar to the ones listed here:",
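
As a rough sketch of how such a template could be filled in, the snippet below substitutes the sentence count and appends sampled texts, one per line. The helper name build_prompt and the sample sentences are hypothetical, and treating SUMMARY_COUNT as a literal placeholder to replace is an assumption:

```python
from typing import List

PROMPT_TEMPLATE = (
    "Create SUMMARY_COUNT unique, informally written sentences "
    "similar to the ones listed here:"
)


def build_prompt(template: str, texts: List[str], summary_count: int) -> str:
    """Fill in the sentence count and append the sampled texts, one per line."""
    header = template.replace("SUMMARY_COUNT", str(summary_count))
    return header + "\n" + "\n".join(texts)


# Made-up samples in the style of messages to a real estate agent
samples = [
    "hey is the house on elm st still available?",
    "can we do a showing this weekend",
]
print(build_prompt(PROMPT_TEMPLATE, samples, summary_count=10))
```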

This approach could be described as Generative Summarization for Data Augmentation. The get_generative_summarization function illustrates how samples extracted from the dataset are injected into the prompt, along with a minimum and maximum word count:

def get_generative_summarization(self, texts: List[str]) -> str:
    """
    Computes generative summarization of specified text

    :param texts: List of texts to create summarization for
    :return: generative summarization text
    """
    logger.info("Generating summarization...")
    prompt = self.prompt

    # Set min and max word counts for summarization based on sampled
    # text if not specified in constructor. A local copy is modified so
    # that self.prompt does not accumulate length clauses across calls.
    if self.min_length is None and self.max_length is None:
        min_length, max_length = get_min_max_word_counts(texts)
        prompt = prompt.replace(
            ":",
            f" with a minimum length of {min_length} words "
            f"and a maximum length of {max_length} words:",
        )

    # Append the sampled texts after the length constraint is in place
    prompt = prompt + "\n" + "\n".join(texts)

    output = self.generator.generate_summary(prompt)
    if self.debug:
        logger.info(f"\nSummarized text: \n{output}")

    return output
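
The get_min_max_word_counts helper is not shown in this article. One plausible version, assuming it simply counts whitespace-delimited words in each sampled text, might look like the sketch below; the actual library implementation may differ:

```python
from typing import List, Tuple


def get_min_max_word_counts(texts: List[str]) -> Tuple[int, int]:
    """Return the smallest and largest whitespace-delimited word counts in texts."""
    if not texts:
        raise ValueError("texts must contain at least one string")
    counts = [len(text.split()) for text in texts]
    return min(counts), max(counts)


# Made-up samples: 2 words and 6 words respectively
print(get_min_max_word_counts(["still available?", "can we do a showing saturday"]))
```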

A common use case might be to augment underrepresented intents in a chatbot dataset. In light of this, it made sense to simplify the expected dataset from one-hot-encoded features to a single text-based classifier column, while keeping the ability to augment specific classifier values by specifying them in the classifier_values constructor parameter.

Requests to ChatGPT through OpenAI’s API are rather simple, as seen below. However, the ChatCompletion call can intermittently fail, which is mitigated by wrapping the request in retry attempts:

def call_chatgpt(self, prompt: str, retry_attempts: int = 3) -> str:
    """
    Calls OpenAI's ChatGPT API to generate text based on prompt.

    :param prompt (:obj:`string`): Prompt to generate text from.
    :param retry_attempts (:obj:`int`, `optional`, defaults to 3): Number
        of retry attempts to make if OpenAI fails.
    :return: response (:obj:`string`): Generated text.
    """
    # Note: openai.ChatCompletion and openai.error belong to the
    # pre-1.0 openai package
    messages = [{"role": "user", "content": prompt}]
    # OpenAI seems to intermittently fail, so we'll retry a few times
    response = ''  # returned as-is if every attempt fails
    attempts = 0
    wait_time = 1
    while attempts < retry_attempts:
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                temperature=self.temperature,
                messages=messages
            )["choices"][0]["message"]["content"]
            break
        except (openai.error.RateLimitError,
                openai.error.APIConnectionError) as err:
            attempts += 1
            # Linear backoff: wait 1s, 2s, 3s, ...
            time.sleep(wait_time * attempts)
            if self.debug:
                logger.warning(f"ERROR: {err}")

    return response
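
The same retry pattern can be tried out without network access. The sketch below is a generic, simplified variant rather than gensum's code: call_with_retries, the flaky stub, and the use of ConnectionError in place of OpenAI's error classes are all assumptions made for illustration.

```python
import time
from typing import Callable, Tuple, Type


def call_with_retries(
    func: Callable[[], str],
    retry_attempts: int = 3,
    wait_time: float = 1.0,
    retry_on: Tuple[Type[Exception], ...] = (ConnectionError,),
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Call func, retrying on the given exceptions with a linearly growing wait."""
    for attempt in range(1, retry_attempts + 1):
        try:
            return func()
        except retry_on:
            # Wait 1x, 2x, 3x ... wait_time after each failed attempt
            sleep(wait_time * attempt)
    # Every attempt failed: return an empty string, as call_chatgpt does
    return ""


# Stub that fails twice before succeeding, standing in for a flaky API call
calls = {"count": 0}


def flaky() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("simulated outage")
    return "generated text"


waits = []  # record requested waits instead of actually sleeping
result = call_with_retries(flaky, sleep=waits.append)
print(result, waits)
```

Passing the sleep function as a parameter keeps the example fast and makes the backoff schedule easy to inspect.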

gensum Python Library

gensum was created to simplify the implementation of this solution for the open-source community. It’s an NLP library based on absum that uses generative summarization to perform data augmentation by oversampling under-represented classes in text classification datasets. Recent advancements in generative models such as ChatGPT make this approach well suited to producing realistic but unique data for the augmentation process, addressing the shortcoming in novelty sometimes exhibited by absum.

gensum uses ChatGPT by default, but is designed in a modular way that allows any large language model capable of generative summarization to be used. It is format agnostic, expecting only a DataFrame containing a text column and a classifier column.
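
To illustrate what such modularity could look like, here is a hedged sketch of a pluggable generator. Only the generate_summary method name comes from the code shown earlier; the SummaryGenerator protocol, the EchoGenerator stand-in, and the summarize function are hypothetical, and gensum's actual interface may differ:

```python
from typing import List, Protocol


class SummaryGenerator(Protocol):
    """Anything with a generate_summary(prompt) -> str method can be plugged in."""

    def generate_summary(self, prompt: str) -> str:
        ...


class EchoGenerator:
    """Toy stand-in for an LLM backend: returns the sample lines in upper case."""

    def generate_summary(self, prompt: str) -> str:
        # Skip the instruction line and keep the samples that follow it
        samples = prompt.splitlines()[1:]
        return "\n".join(sample.upper() for sample in samples)


def summarize(generator: SummaryGenerator, prompt: str, texts: List[str]) -> str:
    """Build the full prompt and delegate generation to the supplied backend."""
    full_prompt = prompt + "\n" + "\n".join(texts)
    return generator.generate_summary(full_prompt)


print(summarize(EchoGenerator(), "Rewrite these:", ["is it still for sale", "any open houses"]))
```

Because the calling code depends only on the generate_summary method, swapping in a different model is a matter of supplying another object with that method.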

Installation

Install via pip and set OPENAI_API_KEY:

pip install gensum
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"

Running Code

Running the code on your own dataset is then simply a matter of importing the library’s Augmentor class and calling its gen_sum_augment method as follows:

import pandas as pd
from gensum import Augmentor

csv = 'path_to_csv'
df = pd.read_csv(csv)
augmentor = Augmentor(df)
df_augmented = augmentor.gen_sum_augment()
df_augmented.to_csv(
    csv.replace('.csv', '-augmented.csv'),
    encoding='utf-8',
    index=False
)

An explanation of available parameters is here. As always, PRs are welcome.

Cheers!

Machine Learning Engineer, Automation Architect and Developer across multiple industries including Fortune 500 companies