-
-
Save ThomasMiconi/1065370cf25262d4d2e30c2b9519ebe0 to your computer and use it in GitHub Desktop.
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import openai
# Load your API key from an environment variable or secret management service.
# os.getenv returns None when OPENAI_API_KEY is unset; no validation is done here.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Experiment configuration.
NBTRAIN = 100     # number of (x1, x2, y) examples written into the prompt
NBTEST = 30       # number of held-out points the model is asked to complete
NOISEMULT = 0.2   # half-width of the uniform noise added to the training targets
TESTMULT = 4.0    # test inputs are drawn from [0, TESTMULT)
TRAINMULT = 2.0   # training inputs are drawn from [0, TRAINMULT)
# The function to be regressed:
def fct(in1, in2):
    """Return the noiseless target ``0.3 * in1 + 0.7 * in2``.

    Only ``*`` and ``+`` are applied to the arguments, so scalars and NumPy
    arrays both work (arrays are combined element-wise).
    """
    return .3 * in1 + .7 * in2
# Training data: inputs are uniform on [0, TRAINMULT), truncated to two
# decimals so the values printed with "%.2f" in the prompt are the values
# actually used to compute the targets.
x1 = (100 * TRAINMULT * np.random.rand(NBTRAIN)).astype(int) / 100.0
x2 = (100 * TRAINMULT * np.random.rand(NBTRAIN)).astype(int) / 100.0
y = fct(x1, x2)
# The targets shown in the prompt carry uniform noise in [-NOISEMULT, NOISEMULT).
ynoise = y + NOISEMULT * (2.0 * np.random.rand(y.size) - 1.0)
# Generating test data: NBTEST points with both coordinates drawn from
# [0, TESTMULT) (two decimals), rejecting any (x1, x2) pair that already
# appears in the training set.
# The training pairs are materialised in a set: a bare zip() is a one-shot
# iterator in Python 3, so a membership test against it would exhaust it on
# the first check and every later check would pass vacuously.
x1x2 = set(zip(x1.tolist(), x2.tolist()))
x1t = []
x2t = []
for _ in range(NBTEST):
    while True:
        a = int(100 * TESTMULT * np.random.rand()) / 100.0
        b = int(100 * TESTMULT * np.random.rand()) / 100.0
        if (a, b) not in x1x2:
            break
    x1t.append(a)
    x2t.append(b)
x1t = np.array(x1t)
x2t = np.array(x2t)
# Noiseless ground truth for the test points.
yt = fct(x1t, x2t)
# One sentence per training example, including its (noisy) y value.
linestrain = [
    f"If x1 = {in1:.2f} and x2 = {in2:.2f}, then y = {out:.2f}."
    for in1, in2, out in zip(x1, x2, ynoise)
]
# One sentence per test point, cut off right after "y =" so the model has to
# supply the value.
linestest = [
    f"If x1 = {in1:.2f} and x2 = {in2:.2f}, then y ="
    for in1, in2 in zip(x1t, x2t)
]
# Shared prefix of every prompt: all training sentences, newline-terminated.
prompt_base = "\n".join(linestrain) + "\n"
responses = []
# Prompts for each test case include the whole training data, plus one line
# from the test data (without the y!)
# NOTE(review): openai.Completion with "text-davinci-003" is the legacy
# completions interface -- confirm it is still available with the installed
# SDK version / account before running.
# First decimal number in the completion. A fixed [1:5] slice would assume one
# leading space and a four-character value and silently truncate anything else.
number_pattern = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)")
for nt, testline in enumerate(linestest):
    prompt = prompt_base + testline
    print("Prompt", nt, ":")
    print(prompt + "(End.)")
    response = openai.Completion.create(
        model="text-davinci-003", prompt=prompt, temperature=0, max_tokens=7
    )
    completion = response['choices'][0]['text']
    match = number_pattern.search(completion)
    if match is None:
        raise ValueError(
            f"No number found in completion {completion!r} for test case {nt}"
        )
    responses.append(match.group())
# Predicted y for every test point, in the same order as x1t / x2t.
ytpred = np.array([float(x) for x in responses])
# Scatter plot with the noiseless value on the x axis for both series:
#   - training points: noiseless y vs. the noisy y written into the prompt
#   - test points:     noiseless yt vs. the value completed by the model
# (The training series is plotted as (y, ynoise) so that it agrees with the
# axis labels and with the test series.)
plt.figure()
plt.plot(y, ynoise, '.b', label='Train data (prompt)')
plt.plot(yt, ytpred, '.r', label='Test data (completion)')
plt.ylabel("Predicted y")
plt.xlabel("True y")
# NOTE(review): the title says N(0,0.2), but the noise added to the training
# targets is uniform in [-NOISEMULT, NOISEMULT) -- confirm the intended label.
plt.title('y = .3*x1 + .7*x2 + N(0,0.2)')
plt.legend()
plt.show()
print("End.")
Example of a full prompt (predicts one data point):
If x1 = 0.97 and x2 = 0.43, then y = 0.66.
If x1 = 0.95 and x2 = 1.13, then y = 1.00.
If x1 = 0.86 and x2 = 0.37, then y = 0.40.
If x1 = 1.16 and x2 = 1.24, then y = 1.24.
If x1 = 1.67 and x2 = 1.72, then y = 1.58.
If x1 = 0.01 and x2 = 1.75, then y = 1.38.
If x1 = 0.16 and x2 = 1.63, then y = 1.21.
If x1 = 1.26 and x2 = 0.73, then y = 0.73.
If x1 = 0.75 and x2 = 1.67, then y = 1.29.
If x1 = 1.01 and x2 = 1.50, then y = 1.18.
If x1 = 0.14 and x2 = 1.42, then y = 1.23.
If x1 = 1.18 and x2 = 1.74, then y = 1.38.
If x1 = 1.29 and x2 = 1.82, then y = 1.69.
If x1 = 1.54 and x2 = 0.22, then y = 0.47.
If x1 = 0.77 and x2 = 1.79, then y = 1.65.
If x1 = 0.06 and x2 = 1.80, then y = 1.17.
If x1 = 1.38 and x2 = 1.26, then y = 1.31.
If x1 = 0.91 and x2 = 0.01, then y = 0.32.
If x1 = 1.21 and x2 = 0.62, then y = 0.91.
If x1 = 1.16 and x2 = 1.81, then y = 1.46.
If x1 = 1.35 and x2 = 1.96, then y = 1.79.
If x1 = 1.35 and x2 = 1.84, then y = 1.57.
If x1 = 1.77 and x2 = 1.02, then y = 1.30.
If x1 = 1.69 and x2 = 0.67, then y = 0.92.
If x1 = 0.03 and x2 = 0.91, then y = 0.57.
If x1 = 0.70 and x2 = 1.12, then y = 0.95.
If x1 = 0.83 and x2 = 1.81, then y = 1.69.
If x1 = 0.31 and x2 = 1.94, then y = 1.55.
If x1 = 0.39 and x2 = 1.67, then y = 1.14.
If x1 = 1.54 and x2 = 0.29, then y = 0.49.
If x1 = 0.46 and x2 = 0.03, then y = 0.26.
If x1 = 0.87 and x2 = 0.21, then y = 0.53.
If x1 = 1.44 and x2 = 1.57, then y = 1.37.
If x1 = 1.65 and x2 = 0.16, then y = 0.60.
If x1 = 0.53 and x2 = 1.88, then y = 1.56.
If x1 = 1.45 and x2 = 0.49, then y = 0.62.
If x1 = 0.73 and x2 = 0.87, then y = 0.96.
If x1 = 0.83 and x2 = 0.51, then y = 0.60.
If x1 = 0.20 and x2 = 1.37, then y = 1.05.
If x1 = 0.42 and x2 = 1.09, then y = 0.99.
If x1 = 0.17 and x2 = 0.05, then y = 0.13.
If x1 = 0.55 and x2 = 1.17, then y = 0.90.
If x1 = 0.63 and x2 = 1.56, then y = 1.13.
If x1 = 1.81 and x2 = 0.47, then y = 0.77.
If x1 = 0.48 and x2 = 0.14, then y = 0.24.
If x1 = 0.85 and x2 = 0.38, then y = 0.48.
If x1 = 1.92 and x2 = 0.73, then y = 1.19.
If x1 = 0.83 and x2 = 0.41, then y = 0.73.
If x1 = 1.22 and x2 = 0.35, then y = 0.76.
If x1 = 1.92 and x2 = 0.90, then y = 1.10.
If x1 = 0.54 and x2 = 1.68, then y = 1.27.
If x1 = 1.03 and x2 = 0.12, then y = 0.41.
If x1 = 1.14 and x2 = 1.87, then y = 1.50.
If x1 = 0.46 and x2 = 1.17, then y = 1.02.
If x1 = 1.93 and x2 = 0.18, then y = 0.87.
If x1 = 1.38 and x2 = 1.29, then y = 1.14.
If x1 = 0.40 and x2 = 1.24, then y = 0.80.
If x1 = 0.83 and x2 = 0.96, then y = 1.12.
If x1 = 0.02 and x2 = 1.77, then y = 1.33.
If x1 = 1.43 and x2 = 1.76, then y = 1.81.
If x1 = 0.79 and x2 = 1.67, then y = 1.41.
If x1 = 1.44 and x2 = 1.89, then y = 1.82.
If x1 = 0.18 and x2 = 1.98, then y = 1.36.
If x1 = 1.39 and x2 = 1.28, then y = 1.35.
If x1 = 0.83 and x2 = 1.25, then y = 1.00.
If x1 = 0.68 and x2 = 1.49, then y = 1.43.
If x1 = 0.10 and x2 = 0.20, then y = 0.15.
If x1 = 0.34 and x2 = 0.96, then y = 0.95.
If x1 = 0.33 and x2 = 1.59, then y = 1.37.
If x1 = 1.33 and x2 = 0.64, then y = 1.04.
If x1 = 1.13 and x2 = 1.36, then y = 1.30.
If x1 = 1.42 and x2 = 0.97, then y = 0.95.
If x1 = 0.86 and x2 = 0.88, then y = 0.89.
If x1 = 0.26 and x2 = 1.81, then y = 1.31.
If x1 = 0.86 and x2 = 0.74, then y = 0.74.
If x1 = 1.71 and x2 = 0.93, then y = 1.24.
If x1 = 1.57 and x2 = 0.65, then y = 1.00.
If x1 = 0.13 and x2 = 0.25, then y = 0.07.
If x1 = 0.32 and x2 = 0.83, then y = 0.56.
If x1 = 1.87 and x2 = 0.59, then y = 1.16.
If x1 = 1.96 and x2 = 0.83, then y = 1.19.
If x1 = 0.90 and x2 = 1.96, then y = 1.83.
If x1 = 0.02 and x2 = 1.07, then y = 0.59.
If x1 = 0.77 and x2 = 0.24, then y = 0.39.
If x1 = 0.36 and x2 = 1.93, then y = 1.53.
If x1 = 1.13 and x2 = 1.34, then y = 1.38.
If x1 = 1.78 and x2 = 1.31, then y = 1.36.
If x1 = 1.47 and x2 = 0.30, then y =
GPT can read a set of triplets (x1, x2, y) from its prompt, perform what amounts to multiple linear regression, and predict outputs for new inputs with reasonable accuracy, even for x values outside the range covered by the prompt data (within reason!).
Example prompt, using the function y = .3*x1 + .7*x2 + noise (use many more triplets!):