Skip to content

Instantly share code, notes, and snippets.

@mindcrime
Created February 26, 2017 19:23
Show Gist options
  • Save mindcrime/b440cc83fde2a4cf7735fd37b31a23d8 to your computer and use it in GitHub Desktop.
Save mindcrime/b440cc83fde2a4cf7735fd37b31a23d8 to your computer and use it in GitHub Desktop.
call() method in DataSetDataVecFunction
/**
 * Converts one DataVec record into an ND4J {@link DataSet}.
 *
 * <p>The record is interpreted as follows, checked in this order:
 * <ol>
 *   <li>Exactly two entries that are the same {@link NDArrayWritable} instance: an unlabelled
 *       example whose array serves as both features and labels.</li>
 *   <li>Exactly two entries whose first is an {@link NDArrayWritable}: that array is the
 *       feature vector and the second entry (parsed as a double) is the label, either a
 *       one-hot vector over {@code numPossibleLabels} classes or, when {@code regression} is
 *       set, a scalar.</li>
 *   <li>Anything else: entries at positions {@code labelIndex..labelIndexTo} (inclusive) form
 *       the label, the remaining scalar entries form the feature vector, and an
 *       {@link NDArrayWritable} entry may supply the whole feature vector instead. Non-array
 *       entries whose string form is empty are skipped. When no label column is configured
 *       or inferred (the label index stays negative), the features are also used as labels.</li>
 * </ol>
 *
 * <p>If {@code numPossibleLabels >= 1} and {@code labelIndex} is negative, the last column is
 * taken as the single label column. The configured {@code preProcessor}, if non-null, is
 * applied to the returned data set in every case above.
 *
 * @param currList the record to convert
 * @return the data set built from {@code currList}
 * @throws IllegalStateException if a classification label is requested while
 *     {@code numPossibleLabels < 1}, or a class index lies outside
 *     {@code [0, numPossibleLabels)}
 * @throws UnsupportedOperationException rethrown when a non-{@link NDArrayWritable} feature
 *     entry cannot be read as a double
 */
public DataSet call(List<Writable> currList) throws Exception {
    // Allow people to specify the label index as -1 and infer the last column as the label.
    int labelIndex = this.labelIndex;
    int labelIndexTo = this.labelIndexTo;
    if (numPossibleLabels >= 1 && labelIndex < 0) {
        labelIndex = currList.size() - 1;
        // The inferred label is exactly one column wide. The end index must be inferred too,
        // otherwise the range test below (j <= labelIndexTo) keeps using the configured end
        // index and may never match the inferred column.
        labelIndexTo = labelIndex;
    }

    // No labels: the caller passes the very same NDArrayWritable instance twice.
    // Reference equality (==) is intentional here.
    if (currList.size() == 2
            && currList.get(1) instanceof NDArrayWritable
            && currList.get(0) instanceof NDArrayWritable
            && currList.get(0) == currList.get(1)) {
        NDArrayWritable writable = (NDArrayWritable) currList.get(0);
        return applyPreProcessor(new DataSet(writable.get(), writable.get()));
    }

    // Array features followed by a single label value.
    if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
        double labelValue = Double.parseDouble(currList.get(1).toString());
        INDArray arrayLabel = regression
                ? Nd4j.scalar(labelValue)
                : toClassLabel((int) labelValue);
        INDArray arrayFeatures = ((NDArrayWritable) currList.get(0)).get();
        return applyPreProcessor(new DataSet(arrayFeatures, arrayLabel));
    }

    INDArray label = null;
    INDArray featureVector = null;
    int featureCount = 0;
    int labelCount = 0;
    for (int j = 0; j < currList.size(); j++) {
        Writable current = currList.get(j);
        // Skip empty values. The NDArrayWritable test comes first so toString() is never
        // called on an array writable, which is a major slowdown here.
        if (!(current instanceof NDArrayWritable) && current.toString().isEmpty()) {
            continue;
        }

        if (labelIndex >= 0 && j >= labelIndex && j <= labelIndexTo) {
            if (converter != null) {
                try {
                    current = converter.convert(current);
                } catch (WritableConverterException e) {
                    // NOTE(review): a failed conversion is reported but not fatal; the
                    // unconverted writable is used below.
                    e.printStackTrace();
                }
            }
            if (regression) {
                // Single- and multi-label regression: one slot per label column.
                if (label == null) {
                    label = Nd4j.zeros(labelIndexTo - labelIndex + 1);
                }
                label.putScalar(0, labelCount++, current.toDouble());
            } else {
                // Single-label classification: one-hot encode the class index.
                label = toClassLabel(current.toInt());
            }
        } else {
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    if (regression && labelIndex >= 0) {
                        // Possibly multi-label regression: exclude every label column.
                        int nLabels = labelIndexTo - labelIndex + 1;
                        featureVector = Nd4j.create(1, currList.size() - nLabels);
                    } else {
                        // Classification (one label column) and the no-labels case.
                        featureVector = Nd4j.create(
                                labelIndex >= 0 ? currList.size() - 1 : currList.size());
                    }
                }
                featureVector.putScalar(featureCount++, value);
            } catch (UnsupportedOperationException e) {
                // Presumably toDouble() rejects non-scalar writables; accept an
                // NDArrayWritable as the whole feature vector.
                if (current instanceof NDArrayWritable) {
                    assert featureVector == null;
                    featureVector = ((NDArrayWritable) current).get();
                } else {
                    throw e;
                }
            }
        }
    }

    DataSet ds = new DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
    return applyPreProcessor(ds);
}

/** Applies the configured preProcessor to {@code ds} when one is set, then returns {@code ds}. */
private DataSet applyPreProcessor(DataSet ds) {
    if (preProcessor != null) {
        preProcessor.preProcess(ds);
    }
    return ds;
}

/**
 * Validates a class index and one-hot encodes it over {@code numPossibleLabels} classes.
 *
 * @throws IllegalStateException if {@code numPossibleLabels < 1} or {@code classIndex} lies
 *     outside {@code [0, numPossibleLabels)}
 */
private INDArray toClassLabel(int classIndex) {
    if (numPossibleLabels < 1) {
        throw new IllegalStateException(
                "Number of possible labels invalid, must be >= 1 for classification");
    }
    if (classIndex < 0 || classIndex >= numPossibleLabels) {
        throw new IllegalStateException("Invalid index: got index " + classIndex
                + " but numPossibleLabels is " + numPossibleLabels
                + " (must be 0 <= idx < numPossibleLabels)");
    }
    return FeatureUtil.toOutcomeVector(classIndex, numPossibleLabels);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment