Created
May 3, 2023 21:10
-
-
Save snellingio/5cee5e8a98fd475886428758ac8aead3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?php | |
use Vector; | |
use PHPUnit\Framework\Assert; | |
/*
 * Custom expectation: passes when the value under test lies within
 * $tolerance of $expected (both bounds inclusive).
 *
 * Returns $this so that further expectations can be chained after it.
 */
expect()->extend('toBeWithin', function (float $expected, float $tolerance) {
    $actual = $this->value;

    // Both bound checks report the same failure message, so build it once.
    $message = "Expected {$actual} to be within {$tolerance} of {$expected}";

    Assert::assertGreaterThanOrEqual($expected - $tolerance, $actual, $message);
    Assert::assertLessThanOrEqual($expected + $tolerance, $actual, $message);

    return $this;
});
// A vector compared with an identical copy should yield a similarity of 1.
it('calculates cosine similarity of two identical vectors', function () {
    $vector = [1, 2, 3];

    expect(Vector::cosineSimilarity($vector, $vector))->toBeWithin(1.0, 0.0001);
});
// Perpendicular vectors have a zero dot product, hence a similarity of 0.
it('calculates cosine similarity of two orthogonal vectors', function () {
    $xAxis = [1, 0];
    $yAxis = [0, 1];

    $result = Vector::cosineSimilarity($xAxis, $yAxis);

    expect($result)->toBeWithin(0.0, 0.0001);
});
// Expected value: 32 / (sqrt(14) * sqrt(77)), which is approximately 0.9746.
it('calculates cosine similarity of two arbitrary vectors', function () {
    $first = [1, 2, 3];
    $second = [4, 5, 6];

    $result = Vector::cosineSimilarity($first, $second);

    expect($result)->toBeWithin(0.9746, 0.0001);
});
// Vectors with differing element counts must be rejected.
it('throws an exception for vectors of different lengths', function () {
    $compare = function () {
        return Vector::cosineSimilarity([1, 2, 3], [4, 5]);
    };

    expect($compare)->toThrow(InvalidArgumentException::class);
});
// A non-numeric element in either vector must be rejected.
it('throws an exception for vectors containing non-numeric elements', function () {
    $compare = function () {
        return Vector::cosineSimilarity([1, 2, 'foo'], [4, 5, 6]);
    };

    expect($compare)->toThrow(InvalidArgumentException::class);
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?php | |
class Vector
{
    /**
     * Calculates the cosine similarity between two vectors.
     *
     * Elements are paired by position; array keys are ignored.
     *
     * @param array $vectorA The first vector.
     * @param array $vectorB The second vector.
     * @return float The cosine similarity, clamped to the range [-1.0, 1.0].
     * @throws \InvalidArgumentException If the vectors have different lengths, contain
     *                                   non-numeric values, or have zero magnitude
     *                                   (which includes empty vectors).
     */
    public static function cosineSimilarity(array $vectorA, array $vectorB): float
    {
        // Always compare lengths, including when one of the vectors is empty.
        if (count($vectorA) !== count($vectorB)) {
            throw new \InvalidArgumentException('Vectors must be of the same length');
        }

        foreach ([$vectorA, $vectorB] as $vector) {
            foreach ($vector as $element) {
                if (! is_numeric($element)) {
                    throw new \InvalidArgumentException('Vectors must contain only numeric values');
                }
            }
        }

        // Re-index so elements are paired by position even when the keys differ.
        $vectorA = array_values($vectorA);
        $vectorB = array_values($vectorB);

        $magnitudeProduct = self::magnitude($vectorA) * self::magnitude($vectorB);

        // The similarity is undefined for a zero-magnitude vector; report it as
        // invalid input rather than letting the division below fail.
        if ($magnitudeProduct === 0.0) {
            throw new \InvalidArgumentException('Vectors must have a non-zero magnitude');
        }

        $similarity = self::dotProduct($vectorA, $vectorB) / $magnitudeProduct;

        // Floating-point rounding can push the quotient marginally outside
        // [-1, 1]. Explicit comparisons are used so that NAN passes through.
        if ($similarity > 1.0) {
            return 1.0;
        }

        if ($similarity < -1.0) {
            return -1.0;
        }

        return $similarity;
    }

    /**
     * Calculates the dot product of two vectors.
     *
     * Every key of $vectorA must also exist in $vectorB.
     *
     * @param array $vectorA The first vector.
     * @param array $vectorB The second vector.
     * @return float The dot product of the two vectors.
     */
    private static function dotProduct(array $vectorA, array $vectorB): float
    {
        $result = 0.0;

        foreach ($vectorA as $key => $value) {
            $result += $value * $vectorB[$key];
        }

        return $result;
    }

    /**
     * Calculates the magnitude (Euclidean norm) of a vector.
     *
     * @param array $vector The vector for which to calculate the magnitude.
     * @return float The magnitude of the vector.
     */
    private static function magnitude(array $vector): float
    {
        $sumOfSquares = 0.0;

        foreach ($vector as $value) {
            $sumOfSquares += $value ** 2;
        }

        return sqrt($sumOfSquares);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment