Skip to content

Instantly share code, notes, and snippets.

Last active January 19, 2024 02:44
Show Gist options
  • Save thistleknot/9ec0d422b96e3c159982eb51450fe40d to your computer and use it in GitHub Desktop.
Save thistleknot/9ec0d422b96e3c159982eb51450fe40d to your computer and use it in GitHub Desktop.
Efficient Batching v2
# This method deducts the sampled records from the list it is given (splitting the records between the sample and the remainder).
# Every sample is always 100% full of data until no more samples can be extracted; at that point an empty sample is returned along with the remainder [the remainder is meant to be folded into a new iteration].
# Function to find the combination of values that adds up to the target sum
def find_combination_to_sum(counts, target):
    """Find values (respecting their available counts) that sum exactly to ``target``.

    Solves a bounded subset-sum problem with a dynamic-programming table.

    Args:
        counts: mapping of value -> number of available copies of that value
            (in this file: record length -> number of records of that length).
        target: the exact sum to reach.

    Returns:
        A list of values whose sum is exactly ``target``, each value used at
        most ``counts[value]`` times, or ``None`` if no such combination
        exists (including a negative ``target``). ``target == 0`` yields ``[]``.
    """
    if target < 0:
        return None

    # Expand the counts into individual items. A value can never be used more
    # than target // val times, so cap the copies to keep the table small.
    # Non-positive values are skipped: 0 would raise ZeroDivisionError below
    # and never changes a sum; negative values are already dropped by the cap.
    values = []
    for val, count in counts.items():
        if val <= 0:
            continue
        values.extend([val] * min(count, target // val))

    n = len(values)
    # dp[i][j] is True when some subset of the first i items sums to j.
    dp = [[False] * (target + 1) for _ in range(n + 1)]
    for row in dp:
        row[0] = True  # the empty subset always sums to 0

    for i in range(1, n + 1):
        item = values[i - 1]
        prev, cur = dp[i - 1], dp[i]
        for j in range(1, target + 1):
            # Either reach j without this item, or use it on top of j - item.
            cur[j] = prev[j] or (item <= j and prev[j - item])

    if not dp[n][target]:
        return None

    # Walk back through the table: when reachability at (i, j) differs from
    # (i - 1, j), item i - 1 must have been used to reach j.
    result = []
    i, j = n, target
    while i > 0 and j > 0:
        if dp[i][j] != dp[i - 1][j]:
            result.append(values[i - 1])
            j -= values[i - 1]
        i -= 1
    return result
def sample_and_remove(combination, records):
    """Draw one random record for each length in ``combination`` and remove it.

    Args:
        combination: iterable of record lengths (e.g. the output of
            ``find_combination_to_sum``), or ``None`` when no combination exists.
        records: list of records; a record's length is ``len(record)``.

    Returns:
        ``(sampled_records, modified_records)``. ``sampled_records`` holds one
        record per entry of ``combination``. ``modified_records`` holds all
        other records, regrouped by length (input order is not preserved).
        If ``combination`` is ``None`` or some requested length cannot be
        supplied, ``([], records)`` is returned and ``records`` is untouched.
    """
    import random
    from collections import defaultdict

    if combination is None:
        return [], records

    # Group records by their length so each requested length can be served.
    grouped_records = defaultdict(list)
    for record in records:
        grouped_records[len(record)].append(record)

    sampled_records = []
    for lens_size in combination:
        bucket = grouped_records.get(lens_size)
        if not bucket:
            # Not enough records of this length: give up without side effects
            # (the buckets are new lists, so `records` was never modified).
            return [], records
        # Pop a random position so the exact drawn object is removed
        # (list.remove would drop the first *equal* record instead).
        sampled_records.append(bucket.pop(random.randrange(len(bucket))))

    # Flatten the remaining groups back into a single list.
    modified_records = [item for group in grouped_records.values() for item in group]
    return sampled_records, modified_records
def create_batches_v2(records, block_size, num_batches):
    """Extract ``num_batches`` samples whose record lengths each sum to ``block_size``.

    Args:
        records: list of records; a record's length is ``len(record)``.
        block_size: exact total length each sample must reach.
        num_batches: number of samples to extract.

    Returns:
        ``(samples, remainder)`` where ``samples`` is a list of
        ``num_batches`` samples (each a list of records) and ``remainder``
        holds the records left over. Sampling works on a shallow copy of
        ``records``. If any sample comes back empty (no full sample can be
        extracted), ``([], records)`` is returned so the caller can fold the
        records into a new iteration.
    """
    samples = []
    modified_records = records.copy()
    for _ in range(num_batches):
        sample, modified_records = retrieve_sample(modified_records, block_size, num_batches)
        if not sample:
            # All-or-nothing: a batch that cannot be completely filled
            # discards the partial progress and hands back the input.
            return [], records
        samples.append(sample)
    return samples, modified_records
def retrieve_sample(records, block_size, num_batches):
    """Draw one sample whose record lengths sum exactly to ``block_size``.

    Args:
        records: list of records; a record's length is ``len(record)``.
        block_size: exact total length the sample must reach.
        num_batches: not used by this function; kept for interface compatibility.

    Returns:
        ``(sample, remaining_records)`` as produced by ``sample_and_remove``,
        or ``([], records)`` when no combination of record lengths reaches
        ``block_size`` (e.g. ``records`` is empty and ``block_size > 0``).
    """
    from collections import Counter

    # Count how many records exist for each length. Keys are sorted ascending
    # (the order a pandas groupby produces), so the subset-sum search sees the
    # lengths in the same order as before.
    counts_dict = dict(sorted(Counter(len(s) for s in records).items()))

    combination = find_combination_to_sum(counts_dict, block_size)
    if combination is None:
        # No exact fit exists; avoid passing None on to sample_and_remove.
        return [], records
    return sample_and_remove(combination, records)
# Append the EOS token to every record, keeping only records that still fit
# inside one block once EOS has been added.
train_tokenized = [
    rec + [tokenizer.eos_token_id]
    for rec in train_tokenized
    if args.block_size >= len(rec) + 1
]
val_tokenized = [
    rec + [tokenizer.eos_token_id]
    for rec in val_tokenized
    if args.block_size >= len(rec) + 1
]
# Build batch_size exactly-full samples from the training records.
sampled_train, remainder = create_batches_v2(
    train_tokenized,
    args.block_size,
    args.batch_size,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment