Skip to content

Commit

Permalink
refactor(citrus-3.0): Fix reference resolver usage in receive message…
Browse files Browse the repository at this point in the history
… action builder

We need to makes sure to use the reference resolver at build time only. Otherwise it might not be initialized in the builder yet
  • Loading branch information
christophd committed Mar 9, 2020
1 parent 983748f commit 9f4918b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
import org.springframework.core.io.Resource;
import org.springframework.oxm.Marshaller;
import org.springframework.oxm.XmlMappingException;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.xml.transform.StringResult;
Expand Down Expand Up @@ -456,7 +455,7 @@ public ReceiveMessageAction doBuild() {
}
}

public static abstract class ReceiveMessageActionBuilder<T extends ReceiveMessageAction, B extends ReceiveMessageActionBuilder<T, B>> extends AbstractTestActionBuilder<T, B> {
public static abstract class ReceiveMessageActionBuilder<T extends ReceiveMessageAction, B extends ReceiveMessageActionBuilder<T, B>> extends AbstractTestActionBuilder<T, B> implements ReferenceResolverAware {
private Endpoint endpoint;
private String endpointUri;
private long receiveTimeout = 0L;
Expand All @@ -465,6 +464,7 @@ public static abstract class ReceiveMessageActionBuilder<T extends ReceiveMessag
private AbstractMessageContentBuilder messageBuilder = new PayloadTemplateMessageBuilder();
private List<MessageValidator<? extends ValidationContext>> validators = new ArrayList<>();
private DataDictionary<?> dataDictionary;
private String dataDictionaryName;
private ValidationCallback validationCallback;
private List<ValidationContext> validationContexts = new ArrayList<>();
private List<VariableExtractor> variableExtractors = new ArrayList<>();
Expand All @@ -486,6 +486,11 @@ public static abstract class ReceiveMessageActionBuilder<T extends ReceiveMessag
private XpathPayloadVariableExtractor xpathExtractor;
private JsonPathVariableExtractor jsonPathExtractor;

private final List<String> validatorNames = new ArrayList<>();
private final List<String> headerValidatorNames = new ArrayList<>();
private final Map<String, List<Object>> headerFragmentMappers = new HashMap<>();
private final Map<String, List<Object>> payloadMappers = new HashMap<>();

/**
* Basic bean reference resolver.
*/
Expand Down Expand Up @@ -662,15 +667,9 @@ public B payload(final Object payload, final ObjectMapper objectMapper) {
* @return
*/
public B payloadModel(final Object payload) {
validateApplicationContext();

if (!CollectionUtils.isEmpty(referenceResolver.resolveAll(Marshaller.class))) {
return payload(payload, referenceResolver.resolve(Marshaller.class));
} else if (!CollectionUtils.isEmpty(referenceResolver.resolveAll(ObjectMapper.class))) {
return payload(payload, referenceResolver.resolve(ObjectMapper.class));
}

throw createUnableToFindMapperException();
this.payloadMappers.putIfAbsent("", new ArrayList<>());
this.payloadMappers.get("").add(payload);
return self;
}

/**
Expand All @@ -682,21 +681,9 @@ public B payloadModel(final Object payload) {
* @return
*/
public B payload(final Object payload, final String mapperName) {
validateApplicationContext();

if (referenceResolver.isResolvable(mapperName)) {
final Object mapper = referenceResolver.resolve(mapperName);

if (Marshaller.class.isAssignableFrom(mapper.getClass())) {
return payload(payload, (Marshaller) mapper);
} else if (ObjectMapper.class.isAssignableFrom(mapper.getClass())) {
return payload(payload, (ObjectMapper) mapper);
} else {
throw new CitrusRuntimeException(String.format("Invalid bean type for mapper '%s' expected ObjectMapper or Marshaller but was '%s'", mapperName, mapper.getClass()));
}
}

throw createUnableToFindMapperException();
this.payloadMappers.putIfAbsent(mapperName, new ArrayList<>());
this.payloadMappers.get(mapperName).add(payload);
return self;
}

/**
Expand Down Expand Up @@ -742,15 +729,9 @@ public B header(final String data) {
* @return
*/
public B headerFragment(final Object model) {
validateApplicationContext();

if (!CollectionUtils.isEmpty(referenceResolver.resolveAll(Marshaller.class))) {
return headerFragment(model, referenceResolver.resolve(Marshaller.class));
} else if (!CollectionUtils.isEmpty(referenceResolver.resolveAll(ObjectMapper.class))) {
return headerFragment(model, referenceResolver.resolve(ObjectMapper.class));
}

throw createUnableToFindMapperException();
this.headerFragmentMappers.putIfAbsent("", new ArrayList<>());
this.headerFragmentMappers.get("").add(model);
return self;
}

/**
Expand All @@ -762,21 +743,9 @@ public B headerFragment(final Object model) {
* @return
*/
public B headerFragment(final Object model, final String mapperName) {
validateApplicationContext();

if (referenceResolver.isResolvable(mapperName)) {
final Object mapper = referenceResolver.resolve(mapperName);

if (Marshaller.class.isAssignableFrom(mapper.getClass())) {
return headerFragment(model, (Marshaller) mapper);
} else if (ObjectMapper.class.isAssignableFrom(mapper.getClass())) {
return headerFragment(model, (ObjectMapper) mapper);
} else {
throw new CitrusRuntimeException(String.format("Invalid bean type for mapper '%s' expected ObjectMapper or Marshaller but was '%s'", mapperName, mapper.getClass()));
}
}

throw createUnableToFindMapperException();
this.headerFragmentMappers.putIfAbsent(mapperName, new ArrayList<>());
this.headerFragmentMappers.get(mapperName).add(model);
return self;
}

/**
Expand Down Expand Up @@ -1195,12 +1164,7 @@ public B validators(final List<MessageValidator<? extends ValidationContext>> va
*/
@SuppressWarnings("unchecked")
public B validator(final String... validatorNames) {
validateApplicationContext();

for (final String validatorName : validatorNames) {
this.validators.add(referenceResolver.resolve(validatorName, MessageValidator.class));
}

this.validatorNames.addAll(Arrays.asList(validatorNames));
return self;
}

Expand All @@ -1222,12 +1186,7 @@ public B headerValidator(final HeaderValidator... validators) {
* @return
*/
public B headerValidator(final String... validatorNames) {
validateApplicationContext();

for (final String validatorName : validatorNames) {
getHeaderValidationContext().addHeaderValidator(referenceResolver.resolve(validatorName, HeaderValidator.class));
}

this.headerValidatorNames.addAll(Arrays.asList(validatorNames));
return self;
}

Expand All @@ -1249,8 +1208,7 @@ public B dictionary(final DataDictionary<?> dictionary) {
* @return
*/
public B dictionary(final String dictionaryName) {
validateApplicationContext();
this.dataDictionary = referenceResolver.resolve(dictionaryName, DataDictionary.class);
this.dataDictionaryName = dictionaryName;
return self;
}

Expand Down Expand Up @@ -1296,10 +1254,6 @@ public B extractFromPayload(final String path, final String variable) {
* @return
*/
public B validationCallback(final ValidationCallback callback) {
if (callback instanceof ReferenceResolverAware) {
((ReferenceResolverAware) callback).setReferenceResolver(referenceResolver);
}

this.validationCallback = callback;
return self;
}
Expand All @@ -1324,12 +1278,93 @@ public B withReferenceResolver(final ReferenceResolver referenceResolver) {
return self;
}

/**
* Specifies the referenceResolver.
*
* @param referenceResolver
*/
@Override
public void setReferenceResolver(ReferenceResolver referenceResolver) {
this.referenceResolver = referenceResolver;
}

@Override
public final T build() {
reconcileValidationContexts();

if (referenceResolver != null) {
if (validationCallback != null &&
validationCallback instanceof ReferenceResolverAware) {
((ReferenceResolverAware) validationCallback).setReferenceResolver(referenceResolver);
}

for (final String validatorName : validatorNames) {
this.validators.add(referenceResolver.resolve(validatorName, MessageValidator.class));
}

for (final String validatorName : headerValidatorNames) {
getHeaderValidationContext().addHeaderValidator(referenceResolver.resolve(validatorName, HeaderValidator.class));
}

if (dataDictionaryName != null) {
this.dataDictionary = referenceResolver.resolve(dataDictionaryName, DataDictionary.class);
}

for (Map.Entry<String, List<Object>> mapperEntry : headerFragmentMappers.entrySet()) {
String mapperName = mapperEntry.getKey();
final Object mapper = findMapperOrMarshaller(mapperName);

for (Object model : mapperEntry.getValue()) {
if (Marshaller.class.isAssignableFrom(mapper.getClass())) {
headerFragment(model, (Marshaller) mapper);
} else if (ObjectMapper.class.isAssignableFrom(mapper.getClass())) {
headerFragment(model, (ObjectMapper) mapper);
} else {
throw new CitrusRuntimeException(String.format("Invalid bean type for mapper '%s' expected ObjectMapper or Marshaller but was '%s'", mapperName, mapper.getClass()));
}
}
}

for (Map.Entry<String, List<Object>> mapperEntry : payloadMappers.entrySet()) {
String mapperName = mapperEntry.getKey();
final Object mapper = findMapperOrMarshaller(mapperName);

for (Object model : mapperEntry.getValue()) {
if (Marshaller.class.isAssignableFrom(mapper.getClass())) {
payload(model, (Marshaller) mapper);
} else if (ObjectMapper.class.isAssignableFrom(mapper.getClass())) {
payload(model, (ObjectMapper) mapper);
} else {
throw new CitrusRuntimeException(String.format("Invalid bean type for mapper '%s' expected ObjectMapper or Marshaller but was '%s'", mapperName, mapper.getClass()));
}
}
}
}

return doBuild();
}

/**
* Find mapper or marshaller for given name using the reference resolver in this builder.
* @param mapperName
* @return
*/
private Object findMapperOrMarshaller(String mapperName) {
if (mapperName.equals("")) {
if (!CollectionUtils.isEmpty(referenceResolver.resolveAll(Marshaller.class))) {
return referenceResolver.resolve(Marshaller.class);
} else if (!CollectionUtils.isEmpty(referenceResolver.resolveAll(ObjectMapper.class))) {
return referenceResolver.resolve(ObjectMapper.class);
} else {
throw createUnableToFindMapperException();
}
} else if (referenceResolver.isResolvable(mapperName)) {
return referenceResolver.resolve(mapperName);
} else {
throw createUnableToFindMapperException();
}
}

/**
* Build method implemented by subclasses.
* @return
Expand Down Expand Up @@ -1471,10 +1506,6 @@ private CitrusRuntimeException createUnableToFindMapperException() {
return new CitrusRuntimeException("Unable to resolve default object mapper or marshaller");
}

private void validateApplicationContext() {
Assert.notNull(referenceResolver, "Citrus bean reference resolver is not initialized!");
}

/**
* Revisit configured validation context list and automatically add context based on message payload and path
* expression contexts if any.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1281,13 +1281,13 @@ void headerValidator_fromNames() {
doReturn(validator3).when(referenceResolver).resolve(name3, HeaderValidator.class);
ReflectionTestUtils.setField(builder, "referenceResolver", referenceResolver);


//WHEN
final ReceiveMessageAction.Builder copy = builder.headerValidator(name1, name2, name3);

//THEN
assertSame(copy, builder);

builder.build();
final HeaderValidationContext headerValidationContext =
getFieldFromBuilder(builder, HeaderValidationContext.class, "headerValidationContext");
assertEquals(3, headerValidationContext.getValidators().size());
Expand Down Expand Up @@ -1455,10 +1455,10 @@ void testSetMessageTypeAsString(){
}

private <T> T getFieldFromBuilder(ReceiveMessageAction.Builder builder, final Class<T> targetClass, final String fieldName) {
final T scriptValidationContext = targetClass.cast(
final T validationContext = targetClass.cast(
ReflectionTestUtils.getField(builder, fieldName));
assertNotNull(scriptValidationContext);
return scriptValidationContext;
assertNotNull(validationContext);
return validationContext;
}

private String getPayloadData(ReceiveMessageAction.Builder builder) {
Expand Down

0 comments on commit 9f4918b

Please sign in to comment.