@ThomasMiconi
Last active April 11, 2023 19:05
GPT3.5 doing multiple linear regression in-context
import os
import openai
import numpy as np
import matplotlib.pyplot as plt
# Load your API key from an environment variable or secret management service
openai.api_key = os.getenv("OPENAI_API_KEY")
NBTRAIN = 100
NBTEST = 30
NOISEMULT = .2
TESTMULT = 4.0
TRAINMULT = 2.0
# The function to be regressed:
def fct(in1, in2):
    return .3 * in1 + .7 * in2
# Training data
x1 = (100 * TRAINMULT * np.random.rand(NBTRAIN)).astype(int)/100.0 # np.random.rand(NBTRAIN)
x2 = (100 * TRAINMULT * np.random.rand(NBTRAIN)).astype(int)/100.0 # np.random.rand(NBTRAIN)
y = fct(x1, x2)
ynoise = y + NOISEMULT * (2.0 * np.random.rand(y.size) - 1.0)
# Generating test data: random points that do not already appear in the training set
# A set is used because a bare zip() iterator would be consumed by the first membership test
x1x2 = set(zip(x1, x2))
x1t = []
x2t = []
for numtest in range(NBTEST):
    while True:
        a = int(100 * TESTMULT * np.random.rand())/100.0
        b = int(100 * TESTMULT * np.random.rand())/100.0
        if (a, b) not in x1x2:
            break
    x1t.append(a)
    x2t.append(b)
x1t = np.array(x1t); x2t = np.array(x2t)
yt = fct(x1t, x2t)
linestrain = []
linestest = []
for nn in range(NBTRAIN):
    linestrain.append(f"If x1 = {x1[nn]:.2f} and x2 = {x2[nn]:.2f}, then y = {ynoise[nn]:.2f}.")
for nn in range(NBTEST):
    linestest.append(f"If x1 = {x1t[nn]:.2f} and x2 = {x2t[nn]:.2f}, then y =")
prompt_base = "\n".join(linestrain) + "\n"
responses = []
# Prompts for each test case include the whole training data, plus one line from the test data (without the y!)
for nt in range(NBTEST):
    prompt = prompt_base + linestest[nt]
    print("Prompt", nt, ":")
    print(prompt + "(End.)")
    # Legacy Completion endpoint (openai package < 1.0)
    response = openai.Completion.create(model="text-davinci-003", prompt=prompt, temperature=0, max_tokens=7)
    # Drop the leading space and keep the next four characters (assumes an answer of the form "d.dd")
    responses.append(response['choices'][0]['text'][1:5])
ytpred = np.array([float(x) for x in responses])
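plt.figure()
# Train data: true y on the x-axis, noisy y (as shown in the prompt) on the y-axis
plt.plot(y, ynoise, '.b', label='Train data (prompt)')
plt.plot(yt, ytpred, '.r', label='Test data (completion)')
plt.ylabel("Predicted y")
plt.xlabel("True y")
# The noise is uniform in [-NOISEMULT, NOISEMULT], not Gaussian
plt.title('y = .3*x1 + .7*x2 + U(-0.2, 0.2)')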
plt.legend()
plt.show()
print("End.")
@ThomasMiconi

GPT can read a bunch of triplets (x1, x2, y) from its input, perform multiple linear regression, and predict future outputs with reasonable accuracy, even for new x outside the prompt data (within reason!)
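For comparison, here is a minimal sketch of the ordinary least-squares baseline that the in-context predictions could be checked against. It assumes the same setup as the script (inputs in [0, 2), uniform noise in [-0.2, 0.2]); the fixed seed, the intercept column, and the names `w1`, `w2`, `b` are illustrative choices, not part of the original gist.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed: an assumption, for reproducibility only

# Same generating process as the script: y = .3*x1 + .7*x2 plus uniform noise
n = 100
x1 = rng.uniform(0.0, 2.0, size=n)
x2 = rng.uniform(0.0, 2.0, size=n)
y_noisy = 0.3 * x1 + 0.7 * x2 + 0.2 * (2.0 * rng.random(n) - 1.0)

# Design matrix with an intercept column, solved by least squares
X = np.column_stack([x1, x2, np.ones(n)])
coef, *_ = np.linalg.lstsq(X, y_noisy, rcond=None)
w1, w2, b = coef

print(f"w1 = {w1:.3f}, w2 = {w2:.3f}, intercept = {b:.3f}")

# Prediction for a point outside the training range, as in the test set
print(w1 * 3.5 + w2 * 3.0 + b)
```

The fitted weights should land close to .3 and .7, which gives a reference for judging how "reasonable" the model's completions are.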

Example prompt, using the function y = .3*x1 + .7*x2 + noise (in practice, use many more triplets!):

If x1 = 0.02 and x2 = 1.07, then y = 0.59.
If x1 = 0.77 and x2 = 0.24, then y = 0.39.
If x1 = 1.13 and x2 = 1.34, then y = 1.38.
If x1 = 1.78 and x2 = 1.31, then y = 1.36.
If x1 = 1.47 and x2 = 0.30, then y =
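A prompt of this shape can be assembled with a small helper. The sketch below reuses the four triplets and the query from the example above; the function name `build_prompt` is hypothetical, and the formatting simply mirrors the f-strings in the script.

```python
def build_prompt(examples, query):
    """Format (x1, x2, y) triplets plus one unanswered query line."""
    lines = [f"If x1 = {a:.2f} and x2 = {b:.2f}, then y = {y:.2f}." for a, b, y in examples]
    qa, qb = query
    lines.append(f"If x1 = {qa:.2f} and x2 = {qb:.2f}, then y =")
    return "\n".join(lines)

# The four triplets and the query from the example prompt
examples = [(0.02, 1.07, 0.59), (0.77, 0.24, 0.39), (1.13, 1.34, 1.38), (1.78, 1.31, 1.36)]
prompt = build_prompt(examples, (1.47, 0.30))
print(prompt)

# Noise-free target for the query, from the generating function y = .3*x1 + .7*x2
target = 0.3 * 1.47 + 0.7 * 0.30
print(f"{target:.2f}")  # prints 0.65
```

A completion near 0.65 would therefore match the underlying function for this query.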

[Image: gptlinreg3]

ThomasMiconi commented Apr 11, 2023

Example of a full prompt (predicts one data point):

If x1 = 0.97 and x2 = 0.43, then y = 0.66.
If x1 = 0.95 and x2 = 1.13, then y = 1.00.
If x1 = 0.86 and x2 = 0.37, then y = 0.40.
If x1 = 1.16 and x2 = 1.24, then y = 1.24.
If x1 = 1.67 and x2 = 1.72, then y = 1.58.
If x1 = 0.01 and x2 = 1.75, then y = 1.38.
If x1 = 0.16 and x2 = 1.63, then y = 1.21.
If x1 = 1.26 and x2 = 0.73, then y = 0.73.
If x1 = 0.75 and x2 = 1.67, then y = 1.29.
If x1 = 1.01 and x2 = 1.50, then y = 1.18.
If x1 = 0.14 and x2 = 1.42, then y = 1.23.
If x1 = 1.18 and x2 = 1.74, then y = 1.38.
If x1 = 1.29 and x2 = 1.82, then y = 1.69.
If x1 = 1.54 and x2 = 0.22, then y = 0.47.
If x1 = 0.77 and x2 = 1.79, then y = 1.65.
If x1 = 0.06 and x2 = 1.80, then y = 1.17.
If x1 = 1.38 and x2 = 1.26, then y = 1.31.
If x1 = 0.91 and x2 = 0.01, then y = 0.32.
If x1 = 1.21 and x2 = 0.62, then y = 0.91.
If x1 = 1.16 and x2 = 1.81, then y = 1.46.
If x1 = 1.35 and x2 = 1.96, then y = 1.79.
If x1 = 1.35 and x2 = 1.84, then y = 1.57.
If x1 = 1.77 and x2 = 1.02, then y = 1.30.
If x1 = 1.69 and x2 = 0.67, then y = 0.92.
If x1 = 0.03 and x2 = 0.91, then y = 0.57.
If x1 = 0.70 and x2 = 1.12, then y = 0.95.
If x1 = 0.83 and x2 = 1.81, then y = 1.69.
If x1 = 0.31 and x2 = 1.94, then y = 1.55.
If x1 = 0.39 and x2 = 1.67, then y = 1.14.
If x1 = 1.54 and x2 = 0.29, then y = 0.49.
If x1 = 0.46 and x2 = 0.03, then y = 0.26.
If x1 = 0.87 and x2 = 0.21, then y = 0.53.
If x1 = 1.44 and x2 = 1.57, then y = 1.37.
If x1 = 1.65 and x2 = 0.16, then y = 0.60.
If x1 = 0.53 and x2 = 1.88, then y = 1.56.
If x1 = 1.45 and x2 = 0.49, then y = 0.62.
If x1 = 0.73 and x2 = 0.87, then y = 0.96.
If x1 = 0.83 and x2 = 0.51, then y = 0.60.
If x1 = 0.20 and x2 = 1.37, then y = 1.05.
If x1 = 0.42 and x2 = 1.09, then y = 0.99.
If x1 = 0.17 and x2 = 0.05, then y = 0.13.
If x1 = 0.55 and x2 = 1.17, then y = 0.90.
If x1 = 0.63 and x2 = 1.56, then y = 1.13.
If x1 = 1.81 and x2 = 0.47, then y = 0.77.
If x1 = 0.48 and x2 = 0.14, then y = 0.24.
If x1 = 0.85 and x2 = 0.38, then y = 0.48.
If x1 = 1.92 and x2 = 0.73, then y = 1.19.
If x1 = 0.83 and x2 = 0.41, then y = 0.73.
If x1 = 1.22 and x2 = 0.35, then y = 0.76.
If x1 = 1.92 and x2 = 0.90, then y = 1.10.
If x1 = 0.54 and x2 = 1.68, then y = 1.27.
If x1 = 1.03 and x2 = 0.12, then y = 0.41.
If x1 = 1.14 and x2 = 1.87, then y = 1.50.
If x1 = 0.46 and x2 = 1.17, then y = 1.02.
If x1 = 1.93 and x2 = 0.18, then y = 0.87.
If x1 = 1.38 and x2 = 1.29, then y = 1.14.
If x1 = 0.40 and x2 = 1.24, then y = 0.80.
If x1 = 0.83 and x2 = 0.96, then y = 1.12.
If x1 = 0.02 and x2 = 1.77, then y = 1.33.
If x1 = 1.43 and x2 = 1.76, then y = 1.81.
If x1 = 0.79 and x2 = 1.67, then y = 1.41.
If x1 = 1.44 and x2 = 1.89, then y = 1.82.
If x1 = 0.18 and x2 = 1.98, then y = 1.36.
If x1 = 1.39 and x2 = 1.28, then y = 1.35.
If x1 = 0.83 and x2 = 1.25, then y = 1.00.
If x1 = 0.68 and x2 = 1.49, then y = 1.43.
If x1 = 0.10 and x2 = 0.20, then y = 0.15.
If x1 = 0.34 and x2 = 0.96, then y = 0.95.
If x1 = 0.33 and x2 = 1.59, then y = 1.37.
If x1 = 1.33 and x2 = 0.64, then y = 1.04.
If x1 = 1.13 and x2 = 1.36, then y = 1.30.
If x1 = 1.42 and x2 = 0.97, then y = 0.95.
If x1 = 0.86 and x2 = 0.88, then y = 0.89.
If x1 = 0.26 and x2 = 1.81, then y = 1.31.
If x1 = 0.86 and x2 = 0.74, then y = 0.74.
If x1 = 1.71 and x2 = 0.93, then y = 1.24.
If x1 = 1.57 and x2 = 0.65, then y = 1.00.
If x1 = 0.13 and x2 = 0.25, then y = 0.07.
If x1 = 0.32 and x2 = 0.83, then y = 0.56.
If x1 = 1.87 and x2 = 0.59, then y = 1.16.
If x1 = 1.96 and x2 = 0.83, then y = 1.19.
If x1 = 0.90 and x2 = 1.96, then y = 1.83.
If x1 = 0.02 and x2 = 1.07, then y = 0.59.
If x1 = 0.77 and x2 = 0.24, then y = 0.39.
If x1 = 0.36 and x2 = 1.93, then y = 1.53.
If x1 = 1.13 and x2 = 1.34, then y = 1.38.
If x1 = 1.78 and x2 = 1.31, then y = 1.36.
If x1 = 1.47 and x2 = 0.30, then y =
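The completion for a prompt like this comes back as raw text, which the script slices at fixed character positions. A slightly more tolerant way to extract the number is sketched below; the helper name `parse_first_number` and the sample strings are hypothetical, not actual model output.

```python
import re

def parse_first_number(text):
    """Return the first decimal number found in text as a float, or None if there is none."""
    match = re.search(r"-?\d+(?:\.\d+)?", text)
    return float(match.group()) if match else None

# Hypothetical completion strings (illustrative only)
print(parse_first_number(" 0.65.\nIf x1"))  # prints 0.65
print(parse_first_number(" 1.2"))           # prints 1.2
print(parse_first_number("unknown"))        # prints None
```

Returning `None` instead of raising lets the caller decide whether to skip or retry a test point whose completion contains no number.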
