Created
August 3, 2022 14:37
-
-
Save tteofili/52a563fc67a7fc26fe27d4a69d6ec61e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Anserini: A Lucene toolkit for reproducible information retrieval research
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.anserini.ann; | |
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.OptionHandlerFilter;
import org.kohsuke.args4j.ParserProperties;
public class IndexVectorsHNSW { | |
public static final String FIELD_ID = "id"; | |
public static final String FIELD_VECTOR = "vector"; | |
public static final class Args { | |
@Option(name = "-input", metaVar = "[file]", required = true, usage = "vectors model") | |
public File input; | |
@Option(name = "-path", metaVar = "[path]", required = true, usage = "index path") | |
public Path path; | |
@Option(name="-stored", metaVar = "[boolean]", usage = "store vectors") | |
public boolean stored; | |
} | |
public static void main(String[] args) throws Exception { | |
IndexVectorsHNSW.Args indexArgs = new IndexVectorsHNSW.Args(); | |
CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90)); | |
try { | |
parser.parseArgument(args); | |
} catch (CmdLineException e) { | |
System.err.println(e.getMessage()); | |
parser.printUsage(System.err); | |
System.err.println("Example: " + IndexVectorsHNSW.class.getSimpleName() + | |
parser.printExample(OptionHandlerFilter.REQUIRED)); | |
return; | |
} | |
Analyzer vectorAnalyzer = null; | |
final long start = System.nanoTime(); | |
System.out.println(String.format("Loading model %s", indexArgs.input)); | |
Map<String, List<float[]>> vectors = readVectors(indexArgs.input); | |
Path indexDir = indexArgs.path; | |
if (!Files.exists(indexDir)) { | |
Files.createDirectories(indexDir); | |
} | |
System.out.println(String.format("Creating index at %s...", indexArgs.path)); | |
Directory d = FSDirectory.open(indexDir); | |
IndexWriterConfig conf = new IndexWriterConfig(); | |
IndexWriter indexWriter = new IndexWriter(d, conf); | |
final AtomicInteger cnt = new AtomicInteger(); | |
for (Map.Entry<String, List<float[]>> entry : vectors.entrySet()) { | |
for (float[] vector: entry.getValue()) { | |
Document doc = new Document(); | |
doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES)); | |
doc.add(new KnnVectorField(FIELD_VECTOR, vector, VectorSimilarityFunction.EUCLIDEAN)); | |
try { | |
indexWriter.addDocument(doc); | |
int cur = cnt.incrementAndGet(); | |
if (cur % 100000 == 0) { | |
System.out.println(String.format("%s docs added", cnt)); | |
} | |
} catch (IOException e) { | |
System.err.println("Error while indexing: " + e.getLocalizedMessage()); | |
} | |
} | |
} | |
indexWriter.commit(); | |
System.out.println(String.format("%s docs indexed", cnt.get())); | |
long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L); | |
System.out.println(String.format("Index size: %dMB", space)); | |
indexWriter.close(); | |
d.close(); | |
final long durationMillis = | |
TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); | |
System.out.println(String.format("Total time: %s", | |
DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss"))); | |
} | |
static Map<String, List<float[]>> readVectors(File input) throws IOException { | |
Map<String, List<float[]>> vectors = new HashMap<>(); | |
for (String line : IOUtils.readLines(new FileReader(input))) { | |
String[] s = line.split("\\s+"); | |
if (s.length > 2) { | |
String key = s[0]; | |
float[] vector = new float[s.length - 1]; | |
for (int i = 1; i < s.length; i++) { | |
float f = Float.parseFloat(s[i]); | |
vector[i - 1] = f; | |
} | |
if (vectors.containsKey(key)) { | |
List<float[]> floats = new LinkedList<>(vectors.get(key)); | |
floats.add(vector); | |
vectors.put(key, floats); | |
} else { | |
vectors.put(key, List.of(vector)); | |
} | |
} | |
} | |
return vectors; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment