Created August 3, 2023 18:23
-
-
Save Micky774/0ee61235b751df83da6d1b93512b6ff3 to your computer and use it in GitHub Desktop.
Memory profile target for DistanceMetric32 dtype preservation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.metrics import DistanceMetric
from sklearn.metrics.pairwise import pairwise_distances
# Size of the generated benchmark data: number of feature columns and
# number of sample rows passed to _generate_PWD_data.
N_FEATURES = 10
N_SAMPLES = 10_000
def _generate_PWD_data(
    n_samples_X, n_samples_Y, n_features, n_classes, n_outs=1, random_state=0
):
    """Generate reproducible random inputs for pairwise-distance benchmarks.

    All arrays come from a single seeded ``np.random.RandomState`` and are
    drawn in a fixed order: ``X``, then ``Y``, then ``y``.

    Parameters
    ----------
    n_samples_X : int
        Number of rows in ``X`` and number of labels in ``y``.
    n_samples_Y : int
        Number of rows in ``Y``.
    n_features : int
        Number of columns in both ``X`` and ``Y``.
    n_classes : int
        Labels are integers drawn from ``[0, n_classes)``.
    n_outs : int, default=1
        If 1, ``y`` has shape ``(n_samples_X,)``. Otherwise it has shape
        ``(n_samples_X, n_outs)``.
    random_state : int, default=0
        Seed for the random number generator.

    Returns
    -------
    X : ndarray of shape (n_samples_X, n_features)
        Standard-normal float samples.
    Y : ndarray of shape (n_samples_Y, n_features)
        Standard-normal float samples.
    y : ndarray of int
        Random class labels, shaped as described for ``n_outs``.
    """
    rng = np.random.RandomState(random_state)
    # Keep the draw order X -> Y -> y. Reordering would change every
    # generated value for a given seed.
    X = rng.randn(n_samples_X, n_features)
    Y = rng.randn(n_samples_Y, n_features)
    y_shape = (n_samples_X,) if n_outs == 1 else (n_samples_X, n_outs)
    y = rng.randint(n_classes, size=y_shape)
    return X, Y, y
# Only X is profiled. Index the result instead of unpacking into named
# globals so the unused Y and y arrays are released right away and do not
# inflate the memory-profile baseline.
X = _generate_PWD_data(
    n_samples_X=N_SAMPLES,
    n_samples_Y=N_SAMPLES,
    n_features=N_FEATURES,
    n_classes=2,
    random_state=0,
)[0]

# Sparse CSR input with float32 values. This matches the float32 metric
# requested below (the gist title refers to DistanceMetric32 dtype
# preservation).
X = csr_matrix(X, dtype=np.float32)
dst = DistanceMetric.get_metric(metric="manhattan", dtype=np.float32)

# Call pairwise several times so the profiler can observe the repeated
# output allocations. Each result is discarded on purpose.
for _ in range(5):
    dst.pairwise(X)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment