Skip to content

Instantly share code, notes, and snippets.

@tariqul-islam
Last active October 8, 2020 01:52
Show Gist options
  • Save tariqul-islam/872c1e192e2d5c079a3fb013a5cc8d9d to your computer and use it in GitHub Desktop.
Save tariqul-islam/872c1e192e2d5c079a3fb013a5cc8d9d to your computer and use it in GitHub Desktop.
Numba implementation of Trustworthiness Score
import numpy as np
import numba
from numba import prange
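# Usage, with X the original (high-dimensional) data and Y its low-dimensional
# embedding:
#   sort_idx, _ = get_first_order_graph(X, n_neighbors=15)
#   tt = trustworthiness(Y, sort_idx)
# NOTE: called this way, the score ranks the X-space neighbours of each sample
# by their Y-space distance, which Venna & Kaski call "continuity". For their
# (and scikit-learn's) trustworthiness, swap the roles:
#   sort_idx, _ = get_first_order_graph(Y, n_neighbors=15)
#   tt = trustworthiness(X, sort_idx)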
@numba.jit(nopython=True, parallel=True)
def get_first_order_graph(X, n_neighbors):
    """Brute-force k-nearest-neighbour graph under squared Euclidean distance.

    Parameters
    ----------
    X : 2-D array
        One sample per row.
    n_neighbors : int
        Number of neighbours to keep per sample; must be smaller than
        ``X.shape[0]``.

    Returns
    -------
    sort_idx : (N, n_neighbors) int32 array
        Row ``i`` holds the indices of the ``n_neighbors`` samples closest to
        sample ``i``, nearest first. Sample ``i`` itself is never included.
    dist : (N, N) float32 array
        Full symmetric matrix of squared Euclidean distances (zero diagonal).
        Note this is O(N^2) memory.

    Raises
    ------
    ValueError
        If ``n_neighbors >= N``.
    """
    N = X.shape[0]
    if n_neighbors >= N:
        raise ValueError("n_neighbors must be smaller than the number of samples")
    dist = np.zeros((N, N), dtype=np.float32)
    sort_idx = np.zeros((N, n_neighbors), dtype=np.int32)
    for i in range(N):
        # Compute the upper triangle and mirror it. Row i is complete after
        # this inner loop: entries dist[i, k] with k < i were already written
        # (as the mirror) by earlier iterations of the outer loop.
        for j in prange(i + 1, N):
            d = np.sum((X[i] - X[j]) ** 2)
            dist[i, j] = d
            dist[j, i] = d
        # Exclude sample i explicitly instead of assuming it sorts first:
        # duplicate samples also have distance 0, and the (unstable) sort may
        # place one of them ahead of i, which would put i in its own neighbour
        # list and drop a genuine neighbour.
        row = dist[i, :].copy()
        row[i] = np.inf
        sort_idx[i, :] = np.argsort(row)[:n_neighbors]
    return sort_idx, dist
@numba.jit(nopython=True, parallel=True)
def trustworthiness(Y, sort_idx):
    """Rank-based neighbourhood-preservation score (Venna & Kaski form).

    For every sample ``i`` and every listed neighbour ``j = sort_idx[i, k]``:
    compute the 1-based rank ``r(i, j)`` of ``j`` among the other N-1 samples,
    ordered by squared Euclidean distance to ``i`` in ``Y``. Neighbours ranked
    beyond ``K`` are penalised by ``r(i, j) - K``:

        1 - 2 / (N K (2N - 3K - 1)) * sum_i sum_j max(0, r(i, j) - K)

    The result equals 1 exactly when every listed neighbour is among the K
    nearest samples of ``i`` in ``Y``; larger rank violations lower it.

    NOTE: with ``sort_idx`` built from the embedding and ``Y`` the original
    data, this is trustworthiness as defined by Venna & Kaski (and
    scikit-learn). With ``sort_idx`` built from the original data and ``Y`` the
    embedding, it is the complementary "continuity" measure.

    Parameters
    ----------
    Y : 2-D array
        One sample per row; ranks are measured by distances in this space.
    sort_idx : (N, K) integer array
        Row ``i`` lists the K neighbours of sample ``i`` to be ranked, e.g.
        the first output of ``get_first_order_graph``.

    Returns
    -------
    float
        The score described above.

    Raises
    ------
    ValueError
        If ``sort_idx`` does not have one row per sample, if ``K == 0``, or if
        ``2N - 3K - 1 <= 0`` (normalisation undefined).
    """
    N = Y.shape[0]
    K = sort_idx.shape[1]
    if sort_idx.shape[0] != N:
        raise ValueError("sort_idx must have one row per sample of Y")
    if K == 0 or 2 * N - 3 * K - 1 <= 0:
        raise ValueError("need K >= 1 and 2*N - 3*K - 1 > 0")

    val = 0.0
    for i in prange(N):
        d = np.sum((Y - Y[i, :]) ** 2, axis=1)
        order = np.argsort(d)
        # rank[q] = 1-based position of sample q among the other N-1 samples
        # ordered by distance to i. Sample i is skipped explicitly rather than
        # assumed to sort first (duplicate rows tie with it at distance 0);
        # it keeps rank 0 and so can never be penalised.
        rank = np.zeros(N, dtype=np.int64)
        pos = 0
        for p in range(N):
            q = order[p]
            if q != i:
                pos += 1
                rank[q] = pos
        # Inverse-rank lookup is O(1) per neighbour (vs. an O(N) search).
        # Only neighbours ranked beyond K contribute, by how far out they are.
        penalty = 0.0
        for j in range(K):
            excess = rank[sort_idx[i, j]] - K
            if excess > 0:
                penalty += excess
        val += penalty  # scalar reduction across the prange iterations
    return 1.0 - 2.0 * val / (N * K * (2.0 * N - 3.0 * K - 1.0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment