Skip to content

Instantly share code, notes, and snippets.

@cassc
Created July 22, 2024 08:28
Show Gist options
  • Save cassc/8c6868a9649a2a142cba1986bfe2ec78 to your computer and use it in GitHub Desktop.
Save cassc/8c6868a9649a2a142cba1986bfe2ec78 to your computer and use it in GitHub Desktop.
Get embeddings and calculate similarity
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
# Module-level API client; cal_embedding sends every embeddings request
# through it.
# NOTE(review): no credentials are passed here, so the client presumably
# picks them up from the environment — confirm before running.
client = OpenAI()
def cal_embedding(text: str, model: str = "text-embedding-ada-002") -> np.ndarray:
    """Fetch the embedding of *text* and return it as a single-row array.

    Args:
        text: The string to embed. Newlines are replaced with spaces before
            the request is sent.
        model: Name of the embedding model passed to the API.

    Returns:
        A 2-D NumPy array with exactly one row containing the embedding
        vector returned for *text*.
    """
    # Flatten the text onto one line before sending it.
    single_line = text.replace("\n", " ")
    response = client.embeddings.create(input=[single_line], model=model)
    # Only one input was submitted, so the result of interest is data[0].
    embedding = response.data[0].embedding
    # One row, as many columns as the vector has: the 2-D layout that
    # similarity() hands to cosine_similarity.
    return np.reshape(embedding, (1, -1))
def similarity(ta, tb):
    """Return the cosine similarity between the embeddings of two texts.

    Args:
        ta: First text to compare.
        tb: Second text to compare.

    Returns:
        The scalar cosine similarity of the two embeddings. Each call embeds
        both texts anew via cal_embedding (two embedding requests per call).
    """
    embedding_a = cal_embedding(ta)
    embedding_b = cal_embedding(tb)
    # Both inputs are single-row matrices, so the pairwise result is 1x1;
    # unwrap it to the lone scalar.
    return cosine_similarity(embedding_a, embedding_b)[0][0]
# Experiment 1: Python one-liners.
# i1 and i3 have the same body (a-b) under different names, while i2 has a
# different body (a+b) but a name close to i1's.
i1 = "def sum(a, b): return a-b"
i2 = "def my_sum(a, b): return a+b"
i3 = "def minus(a, b): return a-b"

a = similarity(i1, i2)
b = similarity(i1, i3)
c = similarity(i3, i2)

print(a)
print('>', b)  # '>' marks the score the check below expects to be the largest
print(c)

# "correct" means the pair with identical bodies (i1, i3) scored at least as
# high as both other pairs.
if max(a, b, c) == b:
    print('correct')
else:
    print('incorrect')
# Experiment 2: Solidity-style functions.
# i1 and i3 have the same body (a + b) under different names, while i2 has a
# different body (a - b) but a name close to i1's.
# The literals below are the exact text sent for embedding; keep them verbatim.
i1 = '''
function add1(uint a, uint b) public pure returns(uint){
return a + b;
}
'''
i2 = '''
function add2(uint a, uint b) public pure returns(uint){
return a - b;
}
'''
i3 = '''
function three(uint a, uint b) public pure returns(uint){
return a + b;
}
'''

# Visual separator between this experiment's output and anything printed before.
print('*' * 80)

a = similarity(i1, i2)
b = similarity(i1, i3)
c = similarity(i3, i2)

print(a)
print('>', b)  # '>' marks the score the check below expects to be the largest
print(c)

# "correct" means the pair with identical bodies (i1, i3) scored at least as
# high as both other pairs.
if max(a, b, c) == b:
    print('correct')
else:
    print('incorrect')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment