Skip to content

Instantly share code, notes, and snippets.

@arnabkaycee
Created January 13, 2021 13:22
Show Gist options
  • Save arnabkaycee/4070a0f117256b763abb324f4014f843 to your computer and use it in GitHub Desktop.
Save arnabkaycee/4070a0f117256b763abb324f4014f843 to your computer and use it in GitHub Desktop.
Merkle Root
package com.arnabchatterjee.core;

import com.arnabchatterjee.beans.KVPair;
import com.arnabchatterjee.beans.LeafNode;
import com.arnabchatterjee.beans.Node;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.bouncycastle.util.encoders.Hex;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Builds SHA-256 Merkle trees over lists of {@link KVPair}s and produces / verifies
 * inclusion proofs for individual pairs.
 *
 * <p>A leaf's hash is SHA-256 of {@code kvPair.toString()}; an inner node's hash is SHA-256 of
 * the concatenation of its children's hex-encoded hashes (left first). Whenever a level has an
 * odd number of nodes, a dummy node whose hash is SHA-256 of {@code "dummy"} is appended so
 * every node has a sibling.
 *
 * <p>NOTE(review): leaves and inner nodes are hashed without a domain-separation prefix
 * (e.g. 0x00 / 0x01 as in RFC 6962), so an inner node's pre-image could be presented as leaf
 * data. Adding a prefix would change every root hash, so it is left as-is here.
 */
public class MerkleTreeBuilder {

    /** Hash pre-image used for the padding node appended to odd-sized levels. */
    private static final String DUMMY_NODE_DATA = "dummy";

    /**
     * Build a Merkle Tree from a list of Attribute key value pairs.
     *
     * <p>Each call builds an independent tree; no state is shared between calls, so proofs can
     * later be requested against any tree that was built.
     *
     * @param srcKvPairs the key value pairs, in leaf order; must be non-empty
     * @return root node of the merkle tree (always an inner node, even for a single pair,
     *         because a lone leaf is paired with a dummy node)
     * @throws NullPointerException if {@code srcKvPairs} is null
     * @throws IllegalArgumentException if {@code srcKvPairs} is empty
     */
    public static Node buildMerkleTree(List<KVPair> srcKvPairs) {
        if (srcKvPairs.isEmpty()) {
            // Without this guard, log(0) = -Infinity casts to Integer.MIN_VALUE and the
            // level loop would spin ~2^31 times before failing on get(0).
            throw new IllegalArgumentException("srcKvPairs must contain at least one pair");
        }

        // A mutable ArrayList is required because padToEven may append a dummy node.
        List<Node> level = new ArrayList<>(srcKvPairs.size() + 1);
        for (KVPair kvPair : srcKvPairs) {
            level.add(newLeafNode(kvPair));
        }
        padToEven(level);

        // Each pass halves the level; odd-sized intermediate levels are padded so that
        // every node has a sibling. Stops once only the root remains.
        while (level.size() > 1) {
            level = buildParentLevel(level);
            if (level.size() > 1) {
                padToEven(level);
            }
        }
        return level.get(0);
    }

    /** Creates a leaf for {@code kvPair} whose hash is SHA-256 of the pair's {@code toString()}. */
    private static LeafNode newLeafNode(KVPair kvPair) {
        LeafNode leafNode = new LeafNode(kvPair);
        leafNode.setHash(getHash(leafNode.getKvPair().toString()));
        return leafNode;
    }

    /** Appends a dummy node when {@code level} holds an odd number of nodes. */
    private static void padToEven(List<Node> level) {
        if (level.size() % 2 != 0) {
            level.add(getDummyNode());
        }
    }

    /**
     * Pairs adjacent nodes of an even-sized level into parent nodes and links each child to
     * its parent.
     *
     * @param children nodes of one level; size must be even
     * @return the next level up, in the same left-to-right order
     */
    private static List<Node> buildParentLevel(List<Node> children) {
        List<Node> parents = new ArrayList<>(children.size() / 2 + 1);
        for (int i = 0; i < children.size(); i += 2) {
            Node newNode = new Node();
            Node leftNode = children.get(i);
            Node rightNode = children.get(i + 1);
            leftNode.setParentNode(newNode);
            rightNode.setParentNode(newNode);
            String newNodeHash = getHash(leftNode.getHash().concat(rightNode.getHash()));

            newNode.setLeftNode(leftNode);
            newNode.setRightNode(rightNode);
            newNode.setHash(newNodeHash);
            parents.add(newNode);
        }
        return parents;
    }

    /** Creates a padding node whose hash is SHA-256 of {@value #DUMMY_NODE_DATA}. */
    private static Node getDummyNode() {
        return new Node(getHash(DUMMY_NODE_DATA));
    }

    /**
     * Collects the sibling of every node on the path from the leaf named {@code key} up to
     * {@code rootNode}, ordered from the leaf level upwards.
     *
     * <p>Each returned sibling has its {@code leftChild} flag set to its position under the
     * shared parent, which {@link #verifyMerkleProof} uses to decide concatenation order.
     *
     * @param key the {@code KVPair} name of the leaf to prove
     * @param rootNode root of the tree to search; the leaf is looked up under this node only
     * @return the proof path (empty if {@code rootNode} is itself the matching leaf)
     * @throws IllegalArgumentException if no leaf under {@code rootNode} has that name
     */
    private static List<Node> getMerkleProofNodes(String key, Node rootNode) {
        Node current = findLeafNode(key, rootNode);
        if (current == null) {
            throw new IllegalArgumentException("Node not found: " + key);
        }
        List<Node> merkleNodeProof = new ArrayList<>();
        // Identity, not hash equality: stop exactly at the requested root.
        while (current != rootNode) {
            merkleNodeProof.add(getSiblingNode(current));
            current = current.getParentNode();
        }
        return merkleNodeProof;
    }

    /**
     * Depth-first search below {@code node} for a leaf whose pair name equals {@code key}.
     *
     * <p>Right subtrees are searched first so that, when several leaves share a name, the
     * right-most (last-inserted) one is returned — the same "last write wins" result the
     * previous name-to-leaf map produced.
     *
     * @return the matching leaf, or {@code null} if there is none
     */
    private static LeafNode findLeafNode(String key, Node node) {
        if (node == null) {
            return null;
        }
        if (node instanceof LeafNode) {
            LeafNode leaf = (LeafNode) node;
            String name = leaf.getKvPair().getName();
            boolean matches = (name == null) ? key == null : name.equals(key);
            return matches ? leaf : null;
        }
        LeafNode found = findLeafNode(key, node.getRightNode());
        return (found != null) ? found : findLeafNode(key, node.getLeftNode());
    }

    /**
     * Returns the sibling of {@code node} and records on that sibling whether it is the left
     * child of the shared parent.
     *
     * <p>Note: this writes the sibling's {@code leftChild} flag in the tree. The value depends
     * only on the sibling's position, so repeated calls are consistent.
     *
     * @throws IllegalStateException if {@code node} has no parent
     */
    private static Node getSiblingNode(Node node) {
        Node parentNode = node.getParentNode();
        if (parentNode == null) {
            throw new IllegalStateException("Node has no parent; cannot determine its sibling");
        }
        // Compare by identity: two siblings may legitimately share a hash (e.g. duplicate pairs).
        if (node == parentNode.getLeftNode()) {
            Node rightChild = parentNode.getRightNode();
            rightChild.setLeftChild(false);
            return rightChild;
        }
        Node leftChild = parentNode.getLeftNode();
        leftChild.setLeftChild(true);
        return leftChild;
    }

    /**
     * Returns the lowercase hex-encoded SHA-256 digest of {@code originalString}'s UTF-8 bytes.
     *
     * @throws IllegalStateException if the runtime does not provide SHA-256
     */
    private static String getHash(String originalString) {
        MessageDigest digest;
        try {
            digest = MessageDigest.getInstance("SHA-256");
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform implementation is required to support SHA-256, so this
            // indicates a broken runtime; fail loudly instead of continuing with a null digest.
            throw new IllegalStateException("SHA-256 MessageDigest is unavailable", e);
        }
        byte[] hash = digest.digest(originalString.getBytes(StandardCharsets.UTF_8));
        // Hex output is pure ASCII; name the charset rather than relying on the platform default.
        return new String(Hex.encode(hash), StandardCharsets.UTF_8);
    }

    /**
     * Recomputes the root hash from a leaf's pre-image and its proof path, and compares it
     * with {@code merkleRoot}'s hash.
     *
     * @param data the leaf's hash pre-image, i.e. the string that was hashed for the leaf
     *             ({@code kvPair.toString()})
     * @param merkleProof siblings from the leaf level upwards, as produced by
     *                    {@link #getMerkleProofNodes}
     * @param merkleRoot root node of the tree the proof claims membership in
     * @return {@code true} if the recomputed hash equals the root's hash
     */
    private static boolean verifyMerkleProof(String data, List<Node> merkleProof, Node merkleRoot) {
        String hash = getHash(data);
        for (Node sibling : merkleProof) {
            // A left sibling's hash goes first, mirroring left.concat(right) in the build.
            hash = sibling.isLeftChild()
                    ? getHash(sibling.getHash() + hash)
                    : getHash(hash + sibling.getHash());
        }
        return merkleRoot.getHash().equals(hash);
    }

    /** Demo: builds a 7-leaf tree, prints it, then proves and verifies membership of "key3". */
    public static void main(String[] args) throws JsonProcessingException {
        KVPair p1 = new KVPair("key1", "val1");
        KVPair p2 = new KVPair("key2", "val2");
        KVPair p3 = new KVPair("key3", "val3");
        KVPair p4 = new KVPair("key4", "val4");
        KVPair p5 = new KVPair("key5", "val5");
        KVPair p6 = new KVPair("key6", "val6");
        KVPair p7 = new KVPair("key7", "val7");
        Node root = MerkleTreeBuilder.buildMerkleTree(List.of(p1, p2, p3, p4, p5, p6, p7));
        System.out.println(root);
        String json = new ObjectMapper().writeValueAsString(root);
        System.out.println(json);

        System.out.println("============");
        List<Node> merkleProofNodes = getMerkleProofNodes("key3", root);
        for (Node node : merkleProofNodes) {
            System.out.println(node.getHash() + " " + (node.isLeftChild() ? "L" : "R"));
        }
        System.out.println("=============");
        // NOTE(review): assumes KVPair.toString() renders name followed by value ("key3val3");
        // confirm against KVPair, otherwise this verification will report false.
        boolean isVerified = verifyMerkleProof("key3val3", merkleProofNodes, root);
        System.out.println(isVerified);
    }

}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment