package at.gv.egiz.eaaf.core.impl.utils;

import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.util.Objects;
import java.util.Set;

import javax.annotation.Nonnull;

/**
 * Java object input-stream with some hardening features.
 *
 * <p>Every class descriptor read from the stream is validated in {@link #resolveClass} according
 * to the configured {@link Mode}; a class that fails validation is rejected with an
 * {@link InvalidClassException} before any instance of it is created.</p>
 *
 * <p>NOTE(review): {@code resolveProxyClass} is not overridden, so the interface list of a
 * serialized dynamic proxy is not checked here. Only the proxy's non-proxy superclass descriptor
 * and field classes go through {@link #resolveClass}.</p>
 *
 * @author tlenz
 *
 */
public class EaafObjectInputStream extends ObjectInputStream {

  private static final String ERROR_MSG = "Unauthorized deserialization attempt";

  private final Set<Class<?>> allowedClasses;
  private final Class<?> firstClassType;
  private final Mode modeOfOperation;

  // Number of classes accepted so far in TYPE_SPECIFIC mode. It is never reset, so only the very
  // first class resolved over the lifetime of this stream is checked against firstClassType.
  private int resolvedClassCount = 0;

  /**
   * Object input-stream with internal class validation.
   *
   * @param is Inputstream to deserialize.
   * @param classes Whitelisted classes
   * @param initalClassType Expected type of the first class found in the stream; required in
   *        {@link Mode#TYPE_SPECIFIC}, ignored in {@link Mode#STRICT}
   * @param mode Operation mode for allowed class checking
   * @throws IOException In case of an error while reading the stream header
   * @throws NullPointerException if {@code classes} or {@code mode} is {@code null}, or if
   *         {@code initalClassType} is {@code null} in {@link Mode#TYPE_SPECIFIC} mode
   */
  public EaafObjectInputStream(@Nonnull InputStream is, @Nonnull Set<Class<?>> classes,
      Class<?> initalClassType, Mode mode) throws IOException {
    super(is);
    this.allowedClasses = Objects.requireNonNull(classes, "classes");
    // Fail closed: without a mode no class would be validated at all.
    this.modeOfOperation = Objects.requireNonNull(mode, "mode");
    if (mode == Mode.TYPE_SPECIFIC) {
      // Checked eagerly instead of surfacing as an NPE in the middle of deserialization.
      Objects.requireNonNull(initalClassType, "initalClassType");
    }
    this.firstClassType = initalClassType;
  }

  /**
   * Resolves a class descriptor only if it passes the check for the configured {@link Mode}.
   *
   * @param desc class descriptor read from the stream
   * @return the resolved class
   * @throws InvalidClassException if the class is not allowed in the current mode
   * @throws IOException on stream errors
   * @throws ClassNotFoundException if the class cannot be found
   */
  @Override
  protected Class<?> resolveClass(ObjectStreamClass desc)
      throws IOException, ClassNotFoundException {
    final String className = desc.getName();
    switch (modeOfOperation) {
      case STRICT:
        // The name is checked before resolving, so a non-whitelisted class is never loaded.
        if (!isValidClass(className)) {
          throw unauthorized(className);
        }
        return super.resolveClass(desc);

      case TYPE_SPECIFIC: {
        // Assignability needs the Class object, so it has to be resolved before the check.
        final Class<?> clazz = super.resolveClass(desc);
        final boolean allowed = resolvedClassCount == 0
            ? firstClassType.isAssignableFrom(clazz)
            : isValidClassType(clazz) || Object.class.getName().equals(className);
        if (!allowed) {
          throw unauthorized(className);
        }
        resolvedClassCount++;
        return clazz;
      }

      default:
        // Unknown mode: reject instead of silently deserializing without any check.
        throw unauthorized(className);
    }
  }

  // Builds the rejection exception; InvalidClassException takes (className, reason).
  private static InvalidClassException unauthorized(String className) {
    return new InvalidClassException(className, ERROR_MSG);
  }

  // True if the class name exactly matches the name of one of the whitelisted classes.
  private boolean isValidClass(String classToDeserialize) {
    return allowedClasses.stream()
        .anyMatch(el -> el.getName().equals(classToDeserialize));
  }

  // True if the class is a whitelisted class or a subtype of one.
  private boolean isValidClassType(Class<?> clazzToCheck) {
    return allowedClasses.stream()
        .anyMatch(el -> el.isAssignableFrom(clazzToCheck));
  }

  /** Strategy used by {@link EaafObjectInputStream#resolveClass} to validate classes. */
  enum Mode {
    /** Every resolved class name must exactly match a whitelisted class. */
    STRICT,
    /**
     * The first resolved class must be assignable to the initial class type; every later class
     * must be a subtype of a whitelisted class, or {@code java.lang.Object}.
     */
    TYPE_SPECIFIC
  }
}