Created
February 26, 2017 19:23
-
-
Save mindcrime/b440cc83fde2a4cf7735fd37b31a23d8 to your computer and use it in GitHub Desktop.
call() method in DataVecDataSetFunction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public DataSet call(List<Writable> currList) throws Exception { | |
//allow people to specify label index as -1 and infer the last possible label | |
int labelIndex = this.labelIndex; | |
if (numPossibleLabels >= 1 && labelIndex < 0) { | |
labelIndex = currList.size() - 1; | |
} | |
INDArray label = null; | |
INDArray featureVector = null; | |
int featureCount = 0; | |
int labelCount = 0; | |
//no labels | |
if(currList.size() == 2 && currList.get(1) instanceof NDArrayWritable && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) { | |
NDArrayWritable writable = (NDArrayWritable)currList.get(0); | |
return new DataSet(writable.get(),writable.get()); | |
} | |
if(currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) { | |
if(!regression) | |
label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(currList.get(1).toString()),numPossibleLabels); | |
else | |
label = Nd4j.scalar(Double.parseDouble(currList.get(1).toString())); | |
NDArrayWritable ndArrayWritable = (NDArrayWritable) currList.get(0); | |
featureVector = ndArrayWritable.get(); | |
return new DataSet(featureVector,label); | |
} | |
for (int j = 0; j < currList.size(); j++) { | |
Writable current = currList.get(j); | |
//ndarray writable is an insane slow down herecd | |
if (!(current instanceof NDArrayWritable) && current.toString().isEmpty()) | |
continue; | |
if (labelIndex >= 0 && j >= labelIndex && j<= labelIndexTo ) { | |
//single label case (classification, single label regression etc) | |
if (converter != null) { | |
try { | |
current = converter.convert(current); | |
} catch (WritableConverterException e) { | |
e.printStackTrace(); | |
} | |
} | |
if(regression){ | |
//single and multi-label regression | |
if(label == null){ | |
label = Nd4j.zeros(labelIndexTo-labelIndex+1); | |
} | |
label.putScalar(0,labelCount++, current.toDouble()); | |
} else { | |
if (numPossibleLabels < 1) | |
throw new IllegalStateException("Number of possible labels invalid, must be >= 1 for classification"); | |
int curr = current.toInt(); | |
if (curr >= numPossibleLabels) | |
throw new IllegalStateException("Invalid index: got index " + curr + " but numPossibleLabels is " + numPossibleLabels + " (must be 0 <= idx < numPossibleLabels"); | |
label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels); | |
} | |
} else { | |
try { | |
double value = current.toDouble(); | |
if (featureVector == null) { | |
if(regression && labelIndex >= 0){ | |
//Handle the possibly multi-label regression case here: | |
int nLabels = labelIndexTo - labelIndex + 1; | |
featureVector = Nd4j.create(1, currList.size() - nLabels); | |
} else { | |
//Classification case, and also no-labels case | |
featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size()); | |
} | |
} | |
featureVector.putScalar(featureCount++, value); | |
} catch (UnsupportedOperationException e) { | |
// This isn't a scalar, so check if we got an array already | |
if (current instanceof NDArrayWritable) { | |
assert featureVector == null; | |
featureVector = ((NDArrayWritable)current).get(); | |
} else { | |
throw e; | |
} | |
} | |
} | |
} | |
DataSet ds = new DataSet(featureVector, (labelIndex >= 0 ? label : featureVector) ); | |
if(preProcessor != null) preProcessor.preProcess(ds); | |
return ds; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment