JPA Association Fetching Validator

source link: https://vladmihalcea.com/jpa-association-fetching-validator/

Introduction

In this article, I’m going to show you how we can build a JPA Association Fetching Validator that asserts whether JPA and Hibernate associations are fetched using joins or secondary queries.

Hibernate does not provide built-in support for programmatically checking how entity associations are fetched. However, its API is flexible enough that we can extend it to meet this non-trivial requirement.

Domain Model

Let’s assume we have the following Post, PostComment, and PostCommentDetails entities:

[Figure: JPA Association Fetching Validator Entities]

The Post parent entity looks as follows:

@Entity(name = "Post")
@Table(name = "post")
public class Post {

    @Id
    private Long id;

    private String title;

    //Getters and setters omitted for brevity
}

Next, we define the PostComment child entity, like this:

@Entity(name = "PostComment")
@Table(name = "post_comment")
public class PostComment {

    @Id
    private Long id;

    @ManyToOne
    private Post post;

    private String review;

    //Getters and setters omitted for brevity
}

Notice that the post association uses the default fetch strategy of the @ManyToOne annotation. That default is the infamous FetchType.EAGER strategy, which causes lots of performance problems, as explained in this article.

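The PostCommentDetails child entity defines a one-to-one association to the PostComment parent entity. Because of @MapsId, it shares its primary key with the associated PostComment, which is why the table uses a comment_id column. Again, the comment association uses the default FetchType.EAGER fetching strategy: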

@Entity(name = "PostCommentDetails")
@Table(name = "post_comment_details")
public class PostCommentDetails {

    @Id
    private Long id;

    @OneToOne
    @MapsId
    @OnDelete(action = OnDeleteAction.CASCADE)
    private PostComment comment;

    private int votes;

    //Getters and setters omitted for brevity
}

The problem with the FetchType.EAGER strategy

So, we have two associations using the FetchType.EAGER anti-pattern. Therefore, when executing the following JPQL query:

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

Hibernate executes the following 3 SQL queries:

SELECT
    pce.comment_id AS comment_2_2_,
    pce.votes AS votes1_2_
FROM
    post_comment_details pce
ORDER BY
    pce.comment_id

SELECT
    pc.id AS id1_1_0_,
    pc.post_id AS post_id3_1_0_,
    pc.review AS review2_1_0_,
    p.id AS id1_0_1_,
    p.title AS title2_0_1_
FROM
    post_comment pc
LEFT OUTER JOIN
    post p ON pc.post_id=p.id
WHERE
    pc.id = 1

SELECT
    pc.id AS id1_1_0_,
    pc.post_id AS post_id3_1_0_,
    pc.review AS review2_1_0_,
    p.id AS id1_0_1_,
    p.title AS title2_0_1_
FROM
    post_comment pc
LEFT OUTER JOIN
    post p ON pc.post_id=p.id
WHERE
    pc.id = 2

This is a classic N+1 query issue. Extra secondary queries are executed to fetch the PostComment associations. On top of that, each of these queries also uses a JOIN to fetch the associated Post entity.

Unless you want to load the entire database with a single query, it’s best to avoid using the FetchType.EAGER anti-pattern.
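The N+1 arithmetic behind the log above can be made concrete with a small simulation that does not use Hibernate. This is only a sketch under simplifying assumptions. The class and method names are hypothetical, and a HashSet stands in for the Persistence Context, so an entity that has already been loaded does not trigger another secondary query.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Hypothetical, Hibernate-free simulation of the statement count produced
 * when each PostCommentDetails row eagerly resolves its PostComment
 * through a secondary select. Not part of any Hibernate API.
 */
public class NPlusOneSimulation {

    /**
     * @param commentIds the comment_id of each row returned by the JPQL query
     * @return 1 (the JPQL query) plus one secondary select per distinct
     *         PostComment that is not yet in the simulated Persistence Context
     */
    public static int countStatements(List<Long> commentIds) {
        int statements = 1; // the main JPQL query
        // Stand-in for the Persistence Context: managed entities are not fetched again
        Set<Long> persistenceContext = new HashSet<>();
        for (Long commentId : commentIds) {
            if (persistenceContext.add(commentId)) {
                statements++; // secondary select (which also JOINs the post table)
            }
        }
        return statements;
    }

    public static void main(String[] args) {
        // The example returns two PostCommentDetails rows, for comments 1 and 2
        System.out.println(countStatements(List.of(1L, 2L))); // prints 3
    }
}
```

With the two rows from the example, the count is 1 + 2 = 3, which matches the three statements Hibernate logged.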

So, let’s see if we can detect these extra secondary queries and JOINs programmatically.

Hibernate Statistics to detect secondary queries

As I explained in this article, not only can Hibernate collect statistical information, but we can even customize the data that gets collected.

For instance, we could monitor how many entities have been fetched per Session using the following SessionStatistics utility:

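import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.stat.internal.StatisticsImpl;
import org.hibernate.stat.spi.StatisticsFactory;
import org.hibernate.stat.spi.StatisticsImplementor;

public class SessionStatistics extends StatisticsImpl {

    // Per-thread registry: entity class -> number of secondary-query fetches
    private static final ThreadLocal<Map<Class, AtomicInteger>>
        entityFetchCountContext = new ThreadLocal<>();

    public SessionStatistics(
            SessionFactoryImplementor sessionFactory) {
        super(sessionFactory);
    }

    @Override
    public void openSession() {
        entityFetchCountContext.set(new LinkedHashMap<>());
        super.openSession();
    }

    @Override
    public void fetchEntity(
            String entityName) {
        Map<Class, AtomicInteger> entityFetchCountMap = entityFetchCountContext
            .get();
        // Guard against fetches on a thread that has no open Session registry
        if (entityFetchCountMap != null) {
            entityFetchCountMap
                .computeIfAbsent(
                    // ReflectionUtils is a utility helper (not part of Hibernate),
                    // assumed to resolve a Class from its fully-qualified name
                    ReflectionUtils.getClass(entityName),
                    clazz -> new AtomicInteger()
                )
                .incrementAndGet();
        }
        super.fetchEntity(entityName);
    }

    @Override
    public void closeSession() {
        entityFetchCountContext.remove();
        super.closeSession();
    }

    public static int getEntityFetchCount(
            String entityClassName) {
        return getEntityFetchCount(
            ReflectionUtils.getClass(entityClassName)
        );
    }

    public static int getEntityFetchCount(
            Class entityClass) {
        Map<Class, AtomicInteger> entityFetchCountMap = entityFetchCountContext
            .get();
        // No Session opened on this thread: nothing has been fetched
        if (entityFetchCountMap == null) {
            return 0;
        }
        AtomicInteger entityFetchCount = entityFetchCountMap.get(entityClass);
        return entityFetchCount != null ? entityFetchCount.get() : 0;
    }

    public static class Factory implements StatisticsFactory {

        public static final Factory INSTANCE = new Factory();

        @Override
        public StatisticsImplementor buildStatistics(
                SessionFactoryImplementor sessionFactory) {
            return new SessionStatistics(sessionFactory);
        }
    }
}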

The SessionStatistics class extends the default Hibernate StatisticsImpl class and overrides the following methods:

  • openSession – this callback method is called when a Hibernate Session is opened. We use this callback to initialize the ThreadLocal storage that holds the entity fetching registry.
  • fetchEntity – this callback is called whenever an entity is fetched from the database using a secondary query. And we use this callback to increase the entity fetching counter.
  • closeSession – this callback method is called when a Hibernate Session is closed. In our case, this is when we need to reset the ThreadLocal storage.

The getEntityFetchCount method allows us to inspect how many instances of a given entity class have been fetched from the database using secondary queries in the current Session.

The Factory nested class implements the StatisticsFactory interface and its buildStatistics method, which Hibernate calls when it bootstraps the SessionFactory.
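The ThreadLocal bookkeeping behind these callbacks can be sketched without Hibernate. The class below is hypothetical. It only mirrors the openSession, fetchEntity, and closeSession lifecycle described above, and it returns 0 when no Session has been opened on the current thread.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Hibernate-free sketch of a per-thread, per-Session fetch counter.
 * Method names mirror the SessionStatistics callbacks, but this class
 * is hypothetical and is not registered with Hibernate.
 */
public class FetchCounterSketch {

    private static final ThreadLocal<Map<Class<?>, AtomicInteger>> CONTEXT =
        new ThreadLocal<>();

    public static void openSession() {
        CONTEXT.set(new LinkedHashMap<>()); // fresh registry for this Session
    }

    public static void fetchEntity(Class<?> entityClass) {
        Map<Class<?>, AtomicInteger> counts = CONTEXT.get();
        if (counts == null) {
            return; // no Session open on this thread: nothing to record
        }
        counts.computeIfAbsent(entityClass, c -> new AtomicInteger())
            .incrementAndGet();
    }

    public static int getEntityFetchCount(Class<?> entityClass) {
        Map<Class<?>, AtomicInteger> counts = CONTEXT.get();
        if (counts == null) {
            return 0;
        }
        AtomicInteger count = counts.get(entityClass);
        return count != null ? count.get() : 0;
    }

    public static void closeSession() {
        CONTEXT.remove(); // avoids leaking the map on pooled threads
    }

    public static void main(String[] args) {
        openSession();
        fetchEntity(String.class);
        fetchEntity(String.class);
        System.out.println(getEntityFetchCount(String.class));  // prints 2
        System.out.println(getEntityFetchCount(Integer.class)); // prints 0
        closeSession();
        System.out.println(getEntityFetchCount(String.class));  // prints 0
    }
}
```

Removing the ThreadLocal value in closeSession matters because application servers reuse threads. A stale map would otherwise leak counts into the next Session that runs on the same thread.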

To configure Hibernate to use the custom SessionStatistics, we have to provide the following two settings:

properties.put(
    AvailableSettings.GENERATE_STATISTICS,
    Boolean.TRUE.toString()
);
properties.put(
    StatisticsInitiator.STATS_BUILDER,
    SessionStatistics.Factory.INSTANCE
);

The first one activates the Hibernate statistics mechanism while the second one tells Hibernate to use a custom StatisticsFactory.

So, let’s see it in action!

assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(2, commentDetailsList.size());
assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(2, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

So, SessionStatistics can only help us detect extra secondary queries. It cannot detect the extra JOINs executed because of FetchType.EAGER associations. That is why the Post fetch count stays at 0, even though the Post entity is loaded via a JOIN along with each PostComment.

Hibernate Event Listeners to detect both secondary queries and extra JOINs

Fortunately for us, Hibernate is extremely customizable since, internally, it’s built on top of the Observer pattern.

Every entity action generates an event that’s handled by an event listener, and we can use this mechanism to monitor the entity fetching behavior.

When an entity is fetched directly using the find method or via a query, a LoadEvent is triggered. The LoadEventListener handles the LoadEvent. Once the entity has been loaded, the PostLoadEventListener handles the subsequent PostLoadEvent.

While Hibernate provides default event handlers for all entity events, we can also prepend or append our own listeners using an Integrator, like the following one:

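import org.hibernate.boot.Metadata;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.event.service.spi.EventListenerRegistry;
import org.hibernate.event.spi.EventType;
import org.hibernate.integrator.spi.Integrator;
import org.hibernate.service.spi.SessionFactoryServiceRegistry;

public class AssociationFetchingEventListenerIntegrator
        implements Integrator {

    public static final AssociationFetchingEventListenerIntegrator INSTANCE =
        new AssociationFetchingEventListenerIntegrator();

    @Override
    public void integrate(
            Metadata metadata,
            SessionFactoryImplementor sessionFactory,
            SessionFactoryServiceRegistry serviceRegistry) {

        final EventListenerRegistry eventListenerRegistry =
            serviceRegistry.getService(EventListenerRegistry.class);

        // Runs before the default LoadEventListener
        eventListenerRegistry.prependListeners(
            EventType.LOAD,
            AssociationFetchPreLoadEventListener.INSTANCE
        );
        // Runs after the default LoadEventListener
        eventListenerRegistry.appendListeners(
            EventType.LOAD,
            AssociationFetchLoadEventListener.INSTANCE
        );
        // Runs after the default PostLoadEventListener
        eventListenerRegistry.appendListeners(
            EventType.POST_LOAD,
            AssociationFetchPostLoadEventListener.INSTANCE
        );
    }

    @Override
    public void disintegrate(
            SessionFactoryImplementor sessionFactory,
            SessionFactoryServiceRegistry serviceRegistry) {
        // Nothing to clean up
    }
}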

Our AssociationFetchingEventListenerIntegrator registers three extra event listeners:

  • An AssociationFetchPreLoadEventListener that is executed before the default Hibernate LoadEventListener
  • An AssociationFetchLoadEventListener that is executed after the default Hibernate LoadEventListener
  • And an AssociationFetchPostLoadEventListener that is executed after the default Hibernate PostLoadEventListener
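The prepend/append ordering can be illustrated with a plain-Java model of an Observer-style listener group. This is a simplified, hypothetical sketch and not Hibernate's own listener group implementation. It ignores details such as duplicate-listener handling and just notifies listeners in registration order.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.function.Consumer;

/**
 * Hypothetical, Hibernate-free model of an event listener group where
 * custom listeners can be prepended or appended around a default one.
 */
public class ListenerChainSketch<E> {

    private final Deque<Consumer<E>> listeners = new ArrayDeque<>();

    public ListenerChainSketch(Consumer<E> defaultListener) {
        listeners.add(defaultListener);
    }

    public void prepend(Consumer<E> listener) {
        listeners.addFirst(listener); // runs before everything registered so far
    }

    public void append(Consumer<E> listener) {
        listeners.addLast(listener); // runs after everything registered so far
    }

    /** Notifies every listener in order (Observer pattern). */
    public void fire(E event) {
        for (Consumer<E> listener : listeners) {
            listener.accept(event);
        }
    }

    public static List<String> demo() {
        List<String> calls = new ArrayList<>();
        ListenerChainSketch<String> load =
            new ListenerChainSketch<>(e -> calls.add("default:" + e));
        load.prepend(e -> calls.add("pre:" + e));
        load.append(e -> calls.add("post:" + e));
        load.fire("PostComment#1");
        return calls;
    }

    public static void main(String[] args) {
        // prints [pre:PostComment#1, default:PostComment#1, post:PostComment#1]
        System.out.println(demo());
    }
}
```

This ordering is what makes the validator work. The prepended listener observes the state before the default LoadEventListener runs, and the appended one observes the state after it.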

To instruct Hibernate to use our custom AssociationFetchingEventListenerIntegrator that registers the extra event listeners, we just have to set the hibernate.integrator_provider configuration property:

properties.put(
    "hibernate.integrator_provider",
    (IntegratorProvider) () -> Collections.singletonList(
        AssociationFetchingEventListenerIntegrator.INSTANCE
    )
);

The AssociationFetchPreLoadEventListener implements the LoadEventListener interface and looks like this:

import org.hibernate.event.spi.LoadEvent;
import org.hibernate.event.spi.LoadEventListener;

public class AssociationFetchPreLoadEventListener
        implements LoadEventListener {

    public static final AssociationFetchPreLoadEventListener INSTANCE =
        new AssociationFetchPreLoadEventListener();

    @Override
    public void onLoad(
            LoadEvent event,
            LoadType loadType) {
        // Snapshot the fetch count before the default LoadEventListener runs
        AssociationFetch.Context
            .get(event.getSession())
            .preLoad(event);
    }
}

The AssociationFetchLoadEventListener also implements the LoadEventListener interface and looks as follows:

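import org.hibernate.event.spi.LoadEvent;
import org.hibernate.event.spi.LoadEventListener;

public class AssociationFetchLoadEventListener
        implements LoadEventListener {

    public static final AssociationFetchLoadEventListener INSTANCE =
        new AssociationFetchLoadEventListener();

    @Override
    public void onLoad(
            LoadEvent event,
            LoadType loadType) {
        // Compare the fetch count after the default LoadEventListener has run
        AssociationFetch.Context
            .get(event.getSession())
            .load(event);
    }
}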

And, the AssociationFetchPostLoadEventListener implements the PostLoadEventListener interface and looks like this:

import org.hibernate.event.spi.PostLoadEvent;
import org.hibernate.event.spi.PostLoadEventListener;

public class AssociationFetchPostLoadEventListener
        implements PostLoadEventListener {

    public static final AssociationFetchPostLoadEventListener INSTANCE =
        new AssociationFetchPostLoadEventListener();

    @Override
    public void onPostLoad(
            PostLoadEvent event) {
        // Record the fully loaded entity instance
        AssociationFetch.Context
            .get(event.getSession())
            .postLoad(event);
    }
}

Notice that all the entity fetching monitoring logic is encapsulated in the following AssociationFetch class:

import static java.util.stream.Collectors.groupingBy;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import javax.persistence.EntityManager;

import org.hibernate.Session;
import org.hibernate.event.spi.LoadEvent;
import org.hibernate.event.spi.PostLoadEvent;

public class AssociationFetch {

    private final Object entity;

    public AssociationFetch(Object entity) {
        this.entity = entity;
    }

    public Object getEntity() {
        return entity;
    }

    // Per-Session registry, stored as a Session property
    public static class Context implements Serializable {

        public static final String SESSION_PROPERTY_KEY = "ASSOCIATION_FETCH_LIST";

        // Fetch count captured before the default LoadEventListener runs
        private Map<String, Integer> entityFetchCountByClassNameMap =
            new LinkedHashMap<>();
        // Associations resolved without a secondary query (e.g. via a JOIN)
        private Set<EntityIdentifier> joinedFetchedEntities =
            new LinkedHashSet<>();
        // Associations resolved by executing a secondary query
        private Set<EntityIdentifier> secondaryFetchedEntities =
            new LinkedHashSet<>();
        // Every entity loaded in this Session, as reported by PostLoadEvent
        private Map<EntityIdentifier, Object> loadedEntities =
            new LinkedHashMap<>();

        public List<AssociationFetch> getAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();
            for (Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                if (joinedFetchedEntities.contains(entityIdentifier) ||
                        secondaryFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }
            return associationFetches;
        }

        public List<AssociationFetch> getJoinedAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();
            for (Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                if (joinedFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }
            return associationFetches;
        }

        public List<AssociationFetch> getSecondaryAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();
            for (Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                if (secondaryFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }
            return associationFetches;
        }

        public Map<Class, List<Object>> getAssociationFetchEntityMap() {
            return getAssociationFetches()
                .stream()
                .map(AssociationFetch::getEntity)
                .collect(groupingBy(Object::getClass));
        }

        public void preLoad(LoadEvent loadEvent) {
            String entityClassName = loadEvent.getEntityClassName();
            entityFetchCountByClassNameMap.put(
                entityClassName,
                SessionStatistics.getEntityFetchCount(
                    entityClassName
                )
            );
        }

        public void load(LoadEvent loadEvent) {
            String entityClassName = loadEvent.getEntityClassName();
            int previousFetchCount = entityFetchCountByClassNameMap.get(
                entityClassName
            );
            int currentFetchCount = SessionStatistics.getEntityFetchCount(
                entityClassName
            );
            // ReflectionUtils is a utility helper (not part of Hibernate),
            // assumed to resolve a Class from its fully-qualified name
            EntityIdentifier entityIdentifier = new EntityIdentifier(
                ReflectionUtils.getClass(loadEvent.getEntityClassName()),
                loadEvent.getEntityId()
            );
            if (loadEvent.isAssociationFetch()) {
                if (currentFetchCount == previousFetchCount) {
                    // No secondary select: the entity came from a JOIN (or was already managed)
                    joinedFetchedEntities.add(entityIdentifier);
                } else if (currentFetchCount > previousFetchCount) {
                    // The default LoadEventListener executed a secondary query
                    secondaryFetchedEntities.add(entityIdentifier);
                }
            }
        }

        public void postLoad(PostLoadEvent postLoadEvent) {
            loadedEntities.put(
                new EntityIdentifier(
                    postLoadEvent.getEntity().getClass(),
                    postLoadEvent.getId()
                ),
                postLoadEvent.getEntity()
            );
        }

        public static Context get(Session session) {
            Context context = (Context) session.getProperties()
                .get(SESSION_PROPERTY_KEY);
            if (context == null) {
                context = new Context();
                session.setProperty(SESSION_PROPERTY_KEY, context);
            }
            return context;
        }

        public static Context get(EntityManager entityManager) {
            return get(entityManager.unwrap(Session.class));
        }
    }

    // Serializable, since it is held by the Serializable Context
    private static class EntityIdentifier implements Serializable {

        private final Class entityClass;
        private final Serializable entityId;

        public EntityIdentifier(Class entityClass, Serializable entityId) {
            this.entityClass = entityClass;
            this.entityId = entityId;
        }

        public Class getEntityClass() {
            return entityClass;
        }

        public Serializable getEntityId() {
            return entityId;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof EntityIdentifier)) return false;
            EntityIdentifier that = (EntityIdentifier) o;
            return Objects.equals(getEntityClass(), that.getEntityClass()) &&
                Objects.equals(getEntityId(), that.getEntityId());
        }

        @Override
        public int hashCode() {
            return Objects.hash(getEntityClass(), getEntityId());
        }
    }
}

And, that’s it!
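The core decision in the load method comes down to comparing two counter values. Below, that decision is isolated as a pure function in a hypothetical FetchClassifierSketch class with no Hibernate dependency. If the SessionStatistics fetch count did not change while the default LoadEventListener ran, no secondary query was issued. If the count grew, a secondary query was executed.

```java
/**
 * Hibernate-free sketch of the classification rule used by
 * AssociationFetch.Context#load. Names are hypothetical.
 */
public class FetchClassifierSketch {

    /** Hypothetical labels for the outcome of one LoadEvent. */
    public enum FetchKind { UNCLASSIFIED, JOINED, SECONDARY }

    public static FetchKind classify(
            boolean associationFetch,
            int previousFetchCount,
            int currentFetchCount) {
        if (!associationFetch) {
            // Direct loads (e.g. an explicit find call) are not tracked
            return FetchKind.UNCLASSIFIED;
        }
        if (currentFetchCount == previousFetchCount) {
            // No secondary select ran: the entity was already available,
            // typically hydrated from a JOIN in the current result set
            return FetchKind.JOINED;
        }
        if (currentFetchCount > previousFetchCount) {
            // The default LoadEventListener had to run a secondary select
            return FetchKind.SECONDARY;
        }
        return FetchKind.UNCLASSIFIED;
    }

    public static void main(String[] args) {
        System.out.println(classify(true, 0, 1));  // prints SECONDARY
        System.out.println(classify(true, 1, 1));  // prints JOINED
        System.out.println(classify(false, 0, 1)); // prints UNCLASSIFIED
    }
}
```

Under this rule, an association that resolves to an entity already managed by the Persistence Context is also reported as joined, because no extra query is executed for it.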

Testing Time

So, let’s see how this new utility works. When we run the same query used at the beginning of this article, we can now capture all the association fetches done while executing the JPQL query:

AssociationFetch.Context context = AssociationFetch.Context.get(
    entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(3, context.getAssociationFetches().size());
assertEquals(2, context.getSecondaryAssociationFetches().size());
assertEquals(1, context.getJoinedAssociationFetches().size());

Map<Class, List<Object>> associationFetchMap = context
    .getAssociationFetchEntityMap();
assertEquals(2, associationFetchMap.size());

for (PostCommentDetails commentDetails : commentDetailsList) {
    assertTrue(
        associationFetchMap.get(PostComment.class)
            .contains(commentDetails.getComment())
    );
    assertTrue(
        associationFetchMap.get(Post.class)
            .contains(commentDetails.getComment().getPost())
    );
}

The tool tells us that 3 more entities are fetched by that query:

  • 2 PostComment entities using two secondary queries
  • one Post entity that’s fetched using a JOIN clause by the secondary queries
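The getAssociationFetchEntityMap method used in the test above simply groups the fetched entities by their runtime class. The sketch below reproduces that grouping. Hypothetical PostComment and Post records stand in for the real entities, and records require Java 16 or later. The result shows why the test expects a map with 2 keys.

```java
import static java.util.stream.Collectors.groupingBy;

import java.util.List;
import java.util.Map;

/**
 * Hibernate-free sketch of the grouping done by
 * AssociationFetch.Context#getAssociationFetchEntityMap.
 * PostComment and Post are hypothetical stand-ins, not the real entities.
 */
public class AssociationFetchGroupingSketch {

    record PostComment(long id) {}

    record Post(long id) {}

    public static Map<Class<?>, List<Object>> group(List<Object> fetchedEntities) {
        // One map entry per entity class, preserving encounter order in each list
        return fetchedEntities.stream().collect(groupingBy(Object::getClass));
    }

    public static void main(String[] args) {
        // Mirrors the first test: two PostComment entities and one shared Post
        List<Object> fetched = List.of(
            new PostComment(1), new PostComment(2), new Post(1));
        Map<Class<?>, List<Object>> byClass = group(fetched);
        System.out.println(byClass.size());                        // prints 2
        System.out.println(byClass.get(PostComment.class).size()); // prints 2
        System.out.println(byClass.get(Post.class).size());        // prints 1
    }
}
```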

If we rewrite the previous query to use JOIN FETCH for both the comment and the post associations:

AssociationFetch.Context context = AssociationFetch.Context.get(
    entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    join fetch pcd.comment pc
    join fetch pc.post
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(3, context.getJoinedAssociationFetches().size());
assertTrue(context.getSecondaryAssociationFetches().isEmpty());

We can see that, indeed, no secondary SQL query is executed this time, and all 3 associated entities are fetched using JOIN clauses.

Cool, right?

Conclusion

Building a JPA Association Fetching Validator is quite doable with Hibernate ORM, since its API provides many extension points.

If you like this JPA Association Fetching Validator tool, then you are going to love Hypersistence Optimizer, which promises tens of checks and validations so that you can get the most out of your Spring Boot or Jakarta EE application.

