Created May 4, 2023 10:38
-
-
Save tempdeltavalue/aec9c3e7d9ba15433dc427c29c45c496 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
func runModel(multiArr: MLMultiArray, img_w: Float, img_h:Float) { | |
let inputShape: [NSNumber] = [1 as NSNumber, | |
256 as NSNumber, | |
64 as NSNumber, | |
64 as NSNumber] | |
let input111 = try! MLMultiArray(shape: [1, 256, 64, 64], dataType: .float64) | |
// input111.withUnsafeMutableBytes { ptr, strides in | |
let data = NSMutableData(data: Data(count: 16777216)) | |
let inputTensor = try! ORTValue(tensorData: data, | |
elementType: ORTTensorElementDataType.float, | |
shape: inputShape) | |
let point_coords_data = NSMutableData(data: [176, 68].data) | |
let inputShape2: [NSNumber] = [1 as NSNumber, | |
1 as NSNumber, | |
2 as NSNumber] | |
let point_coords = try! ORTValue(tensorData: point_coords_data, | |
elementType: ORTTensorElementDataType.float, | |
shape: inputShape2) | |
// Run ORT InferenceSession | |
let point_labels_data = NSMutableData(data: [1].data) | |
let inputShape3: [NSNumber] = [1 as NSNumber, | |
1 as NSNumber] | |
let point_labels = try! ORTValue(tensorData: point_labels_data, | |
elementType: ORTTensorElementDataType.float, | |
shape: inputShape3) | |
let rows = 256 | |
let column = 256 | |
let zero_array_data = NSMutableData(data: [Float](repeating: 0, count: column * rows).data) | |
let inputShape4: [NSNumber] = [1 as NSNumber, | |
1 as NSNumber, | |
rows as NSNumber, | |
column as NSNumber] | |
let mask_input = try! ORTValue(tensorData: zero_array_data, | |
elementType: ORTTensorElementDataType.float, | |
shape: inputShape4) | |
let has_mask_input_data = NSMutableData(data: [0].data) | |
let inputShape5: [NSNumber] = [1 as NSNumber] | |
let has_mask_input = try! ORTValue(tensorData: has_mask_input_data, | |
elementType: ORTTensorElementDataType.float, | |
shape: inputShape5) | |
let orig_im_size_data = NSMutableData(data: [img_w, img_h].data) | |
let inputShape6: [NSNumber] = [2 as NSNumber] | |
let orig_im_size = try! ORTValue(tensorData: orig_im_size_data, | |
elementType: ORTTensorElementDataType.float, | |
shape: inputShape6) | |
let runOptions = try! ORTRunOptions() | |
try! runOptions.setLogSeverityLevel(.info) | |
let outputs = try! session.run(withInputs: ["image_embeddings": inputTensor, | |
"point_coords" : point_coords, | |
"point_labels" : point_labels, | |
"mask_input" : mask_input, | |
"has_mask_input" : has_mask_input, | |
"orig_im_size" : orig_im_size], | |
outputNames: ["masks", "iou_predictions", "low_res_masks"], | |
runOptions: runOptions) | |
let tensorData = try! outputs["masks"]!.tensorData() | |
let tensorInfo = try! outputs["masks"]!.tensorTypeAndShapeInfo() | |
let tensorData1 = try! outputs["iou_predictions"]!.tensorData() | |
let tensorInfo1 = try! outputs["iou_predictions"]!.tensorTypeAndShapeInfo() | |
let tensorData2 = try! outputs["low_res_masks"]!.tensorData() | |
let tensorInfo2 = try! outputs["low_res_masks"]!.tensorTypeAndShapeInfo() | |
let ml_arr = try! MLMultiArray(dataPointer: tensorData2.mutableBytes, | |
shape: [tensorInfo2.shape[2], tensorInfo2.shape[3]], | |
dataType: MLMultiArrayDataType.float, | |
strides:[1, 1]) | |
let t_width = tensorInfo.shape[2].intValue | |
let t_height = tensorInfo.shape[3].intValue |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.