From 45a7ae26f0172c636e42b530561eb11efa0083d5 Mon Sep 17 00:00:00 2001 From: Thomas <> Date: Mon, 15 May 2023 08:32:34 +0200 Subject: feat(utils): optimize custom Java serialization implementation --- .../core/impl/utils/EaafObjectInputStream.java | 82 ++++++-- .../core/impl/utils/EaafSerializationUtils.java | 60 +++++- .../test/utils/EaafSerializationUtilsTest.java | 214 +++++++++++++++++++++ 3 files changed, 335 insertions(+), 21 deletions(-) create mode 100644 eaaf_core_utils/src/test/java/at/gv/egiz/eaaf/core/test/utils/EaafSerializationUtilsTest.java (limited to 'eaaf_core_utils') diff --git a/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafObjectInputStream.java b/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafObjectInputStream.java index e15c7a37..1924e165 100644 --- a/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafObjectInputStream.java +++ b/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafObjectInputStream.java @@ -5,35 +5,87 @@ import java.io.InputStream; import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectStreamClass; -import java.util.List; +import java.util.Set; import javax.annotation.Nonnull; +/** + * Java Object stream implementation with some harding features. + * + * @author tlenz + * + */ public class EaafObjectInputStream extends ObjectInputStream { - private List allowedClassNames; - + private final Set> allowedClassNames; + private final Class firstClassType; + private final Mode modeOfOperation; + private int objectDeep = 0; + /** * Object input-stream with internal class validation. - * - * @param is Inputstream to deserialize. - * @param classNames Whitelisted classnames + * + * @param is Inputstream to deserialize. + * @param initalClassType First class type that was found + * @param classes Whitelisted classnames + * @param mode Operation mode for allowed class checking * @throws IOException In case of an error - */ - public EaafObjectInputStream(@Nonnull InputStream is, @Nonnull List classNames) throws IOException { + */ + public EaafObjectInputStream(@Nonnull InputStream is, @Nonnull Set> classes, + Class initalClassType, Mode mode) + throws IOException { super(is); - this.allowedClassNames = classNames; - + this.allowedClassNames = classes; + this.firstClassType = initalClassType; + this.modeOfOperation = mode; + } - //Only deserialize instances of our expected class + // Only deserialize instances of our expected class @Override protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - if (!allowedClassNames.contains(desc.getName())) { - throw new InvalidClassException("Unauthorized deserialization attempt: {}",desc.getName()); - + if (Mode.STRICT.equals(modeOfOperation) && !isValidClass(desc.getName())) { + throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName()); + + } else if (Mode.TYPE_SPECIFIC.equals(modeOfOperation)) { + final Class clazz = super.resolveClass(desc); + if (objectDeep == 0 && !firstClassType.isAssignableFrom(clazz)) { + throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName()); + + } else if (objectDeep > 0 + && !(isValidClassType(clazz) || Object.class.getName().equals(desc.getName()))) { + throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName()); + + } else { + objectDeep++; + return clazz; + + } + + } else { + return super.resolveClass(desc); + } - return super.resolveClass(desc); + } + + private boolean isValidClass(String classToDeserialize) { + return allowedClassNames.stream() + .filter(el -> el.getName().equals(classToDeserialize)) + .findFirst() + .isPresent(); + + } + + private boolean isValidClassType(Class clazzToCheck) { + return allowedClassNames.stream() + .filter(el -> el.isAssignableFrom(clazzToCheck)) + .findFirst() + .isPresent(); + + } + + enum Mode { + STRICT, TYPE_SPECIFIC } } diff --git a/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafSerializationUtils.java b/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafSerializationUtils.java index e15c6800..efb4c9be 100644 --- a/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafSerializationUtils.java +++ b/eaaf_core_utils/src/main/java/at/gv/egiz/eaaf/core/impl/utils/EaafSerializationUtils.java @@ -5,10 +5,12 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.util.List; +import java.util.Set; import org.springframework.lang.Nullable; +import at.gv.egiz.eaaf.core.impl.utils.EaafObjectInputStream.Mode; + public class EaafSerializationUtils { private EaafSerializationUtils() { @@ -42,20 +44,65 @@ public class EaafSerializationUtils { } /** - * Deserialize the byte array into an object. + * Deserialize the byte array into an object with strict allow-list of classes. + * + *

+ * Allow all classes that exact match to elements in allow-list. + *

* - * @param bytes a serialized object - * @param allowedClassName List of classnames that are allowed for deserialization + * @param bytes a serialized object + * @param allowedClassName List of classnames that are explicit allowed for + * deserialization * @return the result of deserializing the bytes */ @Nullable - public static Object deserialize(@Nullable byte[] bytes, List allowedClassName) { + public static Object strictDeserialize(@Nullable byte[] bytes, Set> allowedClassName) { + if (bytes == null) { + return null; + + } + + try (ObjectInputStream ois = new EaafObjectInputStream(new ByteArrayInputStream(bytes), + allowedClassName, null, Mode.STRICT)) { + return ois.readObject(); + + } catch (final IOException ex) { + throw new IllegalArgumentException("Failed to deserialize object", ex); + + } catch (final ClassNotFoundException ex) { + throw new IllegalStateException("Failed to deserialize object type", ex); + + } + } + + /** + * Deserialize the byte array into an object with type-specific allow-list of + * classes. + * + *

+ * Allow all classes that the same or a super-type of elements in + * allow-list.
+ * Hint: Do NOT set {@link Object} as allowed class, because any class is + * an super-type of {@link Object}. This method implementation allows + * {@link Object} as explicit type with strict check-mode. + *

+ * + * @param bytes a serialized object + * @param allowedClassName List of classes that are explicit allowed for + * deserialization + * @param initalClassType First / Initial class type that are required + * @return the result of deserializing the bytes + */ + @Nullable + public static Object typeSpecificDeserialize(@Nullable byte[] bytes, Set> allowedClassName, + Class initalClassType) { if (bytes == null) { return null; } - try (ObjectInputStream ois = new EaafObjectInputStream(new ByteArrayInputStream(bytes), allowedClassName)) { + try (ObjectInputStream ois = new EaafObjectInputStream(new ByteArrayInputStream(bytes), + allowedClassName, initalClassType, Mode.TYPE_SPECIFIC)) { return ois.readObject(); } catch (final IOException ex) { @@ -66,4 +113,5 @@ public class EaafSerializationUtils { } } + } diff --git a/eaaf_core_utils/src/test/java/at/gv/egiz/eaaf/core/test/utils/EaafSerializationUtilsTest.java b/eaaf_core_utils/src/test/java/at/gv/egiz/eaaf/core/test/utils/EaafSerializationUtilsTest.java new file mode 100644 index 00000000..98747b41 --- /dev/null +++ b/eaaf_core_utils/src/test/java/at/gv/egiz/eaaf/core/test/utils/EaafSerializationUtilsTest.java @@ -0,0 +1,214 @@ +package at.gv.egiz.eaaf.core.test.utils; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.Serializable; +import java.util.Collections; + +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.BlockJUnit4ClassRunner; + +import com.google.common.collect.Sets; + +import at.gv.egiz.eaaf.core.exceptions.EaafAuthenticationException; +import at.gv.egiz.eaaf.core.exceptions.EaafException; +import at.gv.egiz.eaaf.core.impl.utils.EaafSerializationUtils; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; + +@RunWith(BlockJUnit4ClassRunner.class) +public class EaafSerializationUtilsTest { + + @Test + public void nullInput() { + assertNull(EaafSerializationUtils.serialize(null)); + assertNull(EaafSerializationUtils.strictDeserialize(null, Sets.newHashSet(Throwable.class))); + assertNull(EaafSerializationUtils.typeSpecificDeserialize(null, + Sets.newHashSet(Throwable.class), String.class)); + + } + + @Test + public void strictMode() { + DummyClassA dummyA = new DummyClassA(rand()); + dummyA.setMemberB(new DummyClassB(rand())); + + byte[] object = EaafSerializationUtils.serialize(dummyA); + + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.strictDeserialize( + object, Sets.newHashSet(Throwable.class))); + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.strictDeserialize( + object, Sets.newHashSet(DummyClassA.class))); + + assertNotNull(EaafSerializationUtils.strictDeserialize(object, Sets.newHashSet( + DummyClassA.class, DummyClassB.class))); + + } + + @Test + public void typeModeSimple() { + DummyClassC dummyC = new DummyClassC(rand()); + + byte[] object = EaafSerializationUtils.serialize(dummyC); + + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class), String.class)); + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassB.class), DummyClassC.class)); + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassC.class), DummyClassA.class)); + + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class), DummyClassC.class)); + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class), DummyClassA.class)); + + } + + @Test + public void typeModeComplex() { + + DummyClassC dummyC = new DummyClassC(rand()); + dummyC.setMemberB(new DummyClassB(rand())); + + byte[] object = EaafSerializationUtils.serialize(dummyC); + + // missing DummyClassB + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class), DummyClassC.class)); + + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassC.class, DummyClassB.class), DummyClassC.class)); + + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassB.class, DummyClassA.class), DummyClassB.class)); + + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class, DummyClassB.class), DummyClassA.class)); + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class, DummyClassB.class), DummyClassC.class)); + + } + + @Test + public void typeModeComplexWithObject() { + + DummyClassC dummyC = new DummyClassC(rand()); + DummyClassD dummyD = new DummyClassD(rand()); + dummyC.setMemberB(dummyD); + dummyD.setAnyType(new EaafException(rand())); + + byte[] object = EaafSerializationUtils.serialize(dummyC); + + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class, DummyClassB.class), DummyClassA.class)); + + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class, DummyClassB.class), DummyClassC.class)); + + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class, DummyClassB.class, + Throwable.class, StackTraceElement[].class, + StackTraceElement.class, Collections.EMPTY_LIST.getClass()), DummyClassA.class)); + + } + + @Test + public void typeModeComplexNullObject() { + + DummyClassC dummyC = new DummyClassC(rand()); + DummyClassD dummyD = new DummyClassD(rand()); + dummyC.setMemberB(dummyD); + + byte[] object = EaafSerializationUtils.serialize(dummyC); + + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(DummyClassA.class, DummyClassB.class), DummyClassA.class)); + + } + + @Test + public void typeModeWithExceptions() { + EaafException error1 = new EaafException(rand()); + EaafAuthenticationException error2 = new EaafAuthenticationException(rand(), null, error1); + + byte[] object = EaafSerializationUtils.serialize(error2); + + // check if less allowed classes throw a deserialization exception + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(Throwable.class), Throwable.class)); + assertThrows(IllegalArgumentException.class, () -> EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(Throwable.class, StackTraceElement[].class, StackTraceElement.class), + Throwable.class)); + + // minimum allowed classes to de-serialize an Throwable + assertNotNull(EaafSerializationUtils.typeSpecificDeserialize( + object, Sets.newHashSet(Throwable.class, StackTraceElement[].class, + StackTraceElement.class, Collections.EMPTY_LIST.getClass()), + Throwable.class)); + + } + + private String rand() { + return RandomStringUtils.randomAlphanumeric(10); + } + + @Getter + @RequiredArgsConstructor + public static class DummyClassA implements Serializable { + + private static final long serialVersionUID = 1L; + + private final String test; + + @Setter + private DummyClassB memberB; + + } + + @Getter + @AllArgsConstructor + public static class DummyClassB implements Serializable { + + private static final long serialVersionUID = 1L; + + private final String test; + + } + + @Getter + public static class DummyClassD extends DummyClassB implements Serializable { + + private static final long serialVersionUID = 1L; + + public DummyClassD(String data) { + super(data); + } + + @Setter + private Object anyType; + + } + + @Getter + public static class DummyClassC extends DummyClassA implements Serializable { + + private static final long serialVersionUID = 1L; + + public DummyClassC(String data) { + super(data); + + this.test = data; + } + + private final String test; + + } + +} -- cgit v1.2.3