flink-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sihuazhou <...@git.apache.org>
Subject [GitHub] flink pull request #6196: [FLINK-9513] Implement TTL state wrappers factory ...
Date Fri, 22 Jun 2018 03:37:08 GMT
Github user sihuazhou commented on a diff in the pull request:

    https://github.com/apache/flink/pull/6196#discussion_r197331764
  
    --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
---
    @@ -0,0 +1,207 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.runtime.state.ttl;
    +
    +import org.apache.flink.api.common.state.AggregatingStateDescriptor;
    +import org.apache.flink.api.common.state.FoldingStateDescriptor;
    +import org.apache.flink.api.common.state.ListStateDescriptor;
    +import org.apache.flink.api.common.state.MapStateDescriptor;
    +import org.apache.flink.api.common.state.ReducingStateDescriptor;
    +import org.apache.flink.api.common.state.State;
    +import org.apache.flink.api.common.state.StateDescriptor;
    +import org.apache.flink.api.common.state.ValueStateDescriptor;
    +import org.apache.flink.api.common.typeutils.CompositeSerializer;
    +import org.apache.flink.api.common.typeutils.TypeSerializer;
    +import org.apache.flink.api.common.typeutils.base.LongSerializer;
    +import org.apache.flink.api.java.tuple.Tuple2;
    +import org.apache.flink.runtime.state.KeyedStateFactory;
    +import org.apache.flink.util.FlinkRuntimeException;
    +
    +import java.util.Arrays;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.stream.Collectors;
    +import java.util.stream.Stream;
    +
    +/**
    + * This state factory wraps state objects, produced by backends, with TTL logic.
    + */
    +public class TtlStateFactory {
    +	/**
    +	 * Creates a keyed state through {@code originalStateFactory} and, unless TTL is disabled,
    +	 * wraps it with TTL logic.
    +	 *
    +	 * @param namespaceSerializer serializer for the state namespace
    +	 * @param stateDesc descriptor of the requested state
    +	 * @param originalStateFactory backend factory that creates the underlying state
    +	 * @param ttlConfig TTL settings; an update type of {@link TtlUpdateType#Disabled} skips wrapping
    +	 * @param timeProvider time provider passed on to the TTL state wrappers
    +	 * @return the backend state as-is, or a TTL wrapper around a backend state
    +	 * @throws Exception propagated from state creation
    +	 */
    +	public static <N, SV, S extends State, IS extends S> IS createStateAndWrapWithTtlIfEnabled(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc,
    +		KeyedStateFactory originalStateFactory,
    +		TtlConfig ttlConfig,
    +		TtlTimeProvider timeProvider) throws Exception {
    +		if (ttlConfig.getTtlUpdateType() == TtlUpdateType.Disabled) {
    +			// TTL switched off: hand out the backend state untouched
    +			return originalStateFactory.createState(namespaceSerializer, stateDesc);
    +		}
    +		TtlStateFactory ttlFactory = new TtlStateFactory(originalStateFactory, ttlConfig, timeProvider);
    +		return ttlFactory.createState(namespaceSerializer, stateDesc);
    +	}
    +
    +	/** Backend factory that creates the underlying states which get wrapped. */
    +	private final KeyedStateFactory originalStateFactory;
    +
    +	/** TTL settings handed to every wrapper created by this factory. */
    +	private final TtlConfig ttlConfig;
    +
    +	/** Time provider handed to every wrapper created by this factory. */
    +	private final TtlTimeProvider timeProvider;
    +
    +	/** State creators, looked up by the exact class of the incoming state descriptor. */
    +	private final Map<Class<? extends StateDescriptor>, StateFactory> stateFactories;
    +
    +	private TtlStateFactory(
    +		KeyedStateFactory originalStateFactory,
    +		TtlConfig ttlConfig,
    +		TtlTimeProvider timeProvider) {
    +		this.originalStateFactory = originalStateFactory;
    +		this.ttlConfig = ttlConfig;
    +		this.timeProvider = timeProvider;
    +		// the creators in this map are method references bound to this instance
    +		this.stateFactories = createStateFactories();
    +	}
    +
    +	/**
    +	 * Builds the lookup table from each supported descriptor class to the method of this
    +	 * factory that creates the corresponding TTL-wrapped state.
    +	 */
    +	private Map<Class<? extends StateDescriptor>, StateFactory> createStateFactories() {
    +		Stream<Tuple2<Class<? extends StateDescriptor>, StateFactory>> entries = Stream.of(
    +			Tuple2.of(ValueStateDescriptor.class, (StateFactory) this::createValueState),
    +			Tuple2.of(ListStateDescriptor.class, (StateFactory) this::createListState),
    +			Tuple2.of(MapStateDescriptor.class, (StateFactory) this::createMapState),
    +			Tuple2.of(ReducingStateDescriptor.class, (StateFactory) this::createReducingState),
    +			Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) this::createAggregatingState),
    +			Tuple2.of(FoldingStateDescriptor.class, (StateFactory) this::createFoldingState));
    +		return entries.collect(Collectors.toMap(entry -> entry.f0, entry -> entry.f1));
    +	}
    +
    +	/**
    +	 * Creates a state for one kind of state descriptor; implemented by the
    +	 * {@code create*State} methods of this factory.
    +	 */
    +	private interface StateFactory {
    +		<N, SV, S extends State, IS extends S> IS create(
    +			TypeSerializer<N> namespaceSerializer,
    +			StateDescriptor<S, SV> descriptor) throws Exception;
    +	}
    +
    +	/**
    +	 * Dispatches to the creator registered for the exact class of {@code stateDesc}.
    +	 *
    +	 * @throws FlinkRuntimeException if no creator is registered for that descriptor class
    +	 */
    +	private <N, SV, S extends State, IS extends S> IS createState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		Class<?> descriptorClass = stateDesc.getClass();
    +		StateFactory factory = stateFactories.get(descriptorClass);
    +		if (factory != null) {
    +			return factory.create(namespaceSerializer, stateDesc);
    +		}
    +		throw new FlinkRuntimeException(String.format(
    +			"State %s is not supported by %s", descriptorClass, TtlStateFactory.class));
    +	}
    +
    +	/** Wraps a value state so that the stored value carries an expiration timestamp. */
    +	@SuppressWarnings("unchecked")
    +	private <N, SV, S extends State, IS extends S> IS createValueState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		SV userDefault = stateDesc.getDefaultValue();
    +		// a present default is stamped with Long.MAX_VALUE as its expiration timestamp
    +		TtlValue<SV> wrappedDefault = null;
    +		if (userDefault != null) {
    +			wrappedDefault = new TtlValue<>(userDefault, Long.MAX_VALUE);
    +		}
    +		TtlSerializer<SV> ttlValueSerializer = new TtlSerializer<>(stateDesc.getSerializer());
    +		ValueStateDescriptor<TtlValue<SV>> ttlDescriptor =
    +			new ValueStateDescriptor<>(stateDesc.getName(), ttlValueSerializer, wrappedDefault);
    +		return (IS) new TtlValueState<>(
    +			originalStateFactory.createState(namespaceSerializer, ttlDescriptor),
    +			ttlConfig, timeProvider, stateDesc.getSerializer());
    +	}
    +
    +	/** Wraps a list state so that every list element carries its own expiration timestamp. */
    +	@SuppressWarnings("unchecked")
    +	private <T, N, SV, S extends State, IS extends S> IS createListState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		ListStateDescriptor<T> userListDescriptor = (ListStateDescriptor<T>) stateDesc;
    +		// only the elements are wrapped; the backend still sees a list state
    +		TtlSerializer<T> ttlElementSerializer = new TtlSerializer<>(userListDescriptor.getElementSerializer());
    +		ListStateDescriptor<TtlValue<T>> ttlDescriptor =
    +			new ListStateDescriptor<>(stateDesc.getName(), ttlElementSerializer);
    +		return (IS) new TtlListState<>(
    +			originalStateFactory.createState(namespaceSerializer, ttlDescriptor),
    +			ttlConfig, timeProvider, userListDescriptor.getSerializer());
    +	}
    +
    +	/** Wraps a map state so that every map value carries its own expiration timestamp. */
    +	@SuppressWarnings("unchecked")
    +	private <UK, UV, N, SV, S extends State, IS extends S> IS createMapState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		MapStateDescriptor<UK, UV> userMapDescriptor = (MapStateDescriptor<UK, UV>) stateDesc;
    +		// keys keep their original serializer; only the values are wrapped
    +		TypeSerializer<UK> userKeySerializer = userMapDescriptor.getKeySerializer();
    +		TtlSerializer<UV> ttlValueSerializer = new TtlSerializer<>(userMapDescriptor.getValueSerializer());
    +		MapStateDescriptor<UK, TtlValue<UV>> ttlDescriptor =
    +			new MapStateDescriptor<>(stateDesc.getName(), userKeySerializer, ttlValueSerializer);
    +		return (IS) new TtlMapState<>(
    +			originalStateFactory.createState(namespaceSerializer, ttlDescriptor),
    +			ttlConfig, timeProvider, userMapDescriptor.getSerializer());
    +	}
    +
    +	/** Wraps a reducing state; the user's reduce function is adapted to TTL-wrapped values. */
    +	@SuppressWarnings("unchecked")
    +	private <N, SV, S extends State, IS extends S> IS createReducingState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		ReducingStateDescriptor<SV> userReducingDescriptor = (ReducingStateDescriptor<SV>) stateDesc;
    +		TtlSerializer<SV> ttlValueSerializer = new TtlSerializer<>(stateDesc.getSerializer());
    +		ReducingStateDescriptor<TtlValue<SV>> ttlDescriptor = new ReducingStateDescriptor<>(
    +			stateDesc.getName(),
    +			new TtlReduceFunction<>(userReducingDescriptor.getReduceFunction(), ttlConfig, timeProvider),
    +			ttlValueSerializer);
    +		return (IS) new TtlReducingState<>(
    +			originalStateFactory.createState(namespaceSerializer, ttlDescriptor),
    +			ttlConfig, timeProvider, stateDesc.getSerializer());
    +	}
    +
    +	/** Wraps an aggregating state so that the accumulator carries an expiration timestamp. */
    +	@SuppressWarnings("unchecked")
    +	private <IN, OUT, N, SV, S extends State, IS extends S> IS createAggregatingState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		AggregatingStateDescriptor<IN, SV, OUT> userAggregatingDescriptor =
    +			(AggregatingStateDescriptor<IN, SV, OUT>) stateDesc;
    +		// one wrapped function instance is shared by the descriptor and the TTL state wrapper
    +		TtlAggregateFunction<IN, SV, OUT> ttlAggregateFunction = new TtlAggregateFunction<>(
    +			userAggregatingDescriptor.getAggregateFunction(), ttlConfig, timeProvider);
    +		TtlSerializer<SV> ttlAccumulatorSerializer = new TtlSerializer<>(stateDesc.getSerializer());
    +		AggregatingStateDescriptor<IN, TtlValue<SV>, OUT> ttlDescriptor = new AggregatingStateDescriptor<>(
    +			stateDesc.getName(), ttlAggregateFunction, ttlAccumulatorSerializer);
    +		return (IS) new TtlAggregatingState<>(
    +			originalStateFactory.createState(namespaceSerializer, ttlDescriptor),
    +			ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction);
    +	}
    +
    +	/** Wraps a folding state so that the accumulator carries an expiration timestamp. */
    +	@SuppressWarnings("unchecked")
    +	private <T, N, SV, S extends State, IS extends S> IS createFoldingState(
    +		TypeSerializer<N> namespaceSerializer,
    +		StateDescriptor<S, SV> stateDesc) throws Exception {
    +		FoldingStateDescriptor<T, SV> userFoldingDescriptor = (FoldingStateDescriptor<T, SV>) stateDesc;
    +		SV initialAccumulator = stateDesc.getDefaultValue();
    +		// a present initial accumulator is stamped with Long.MAX_VALUE as its expiration timestamp
    +		TtlValue<SV> ttlInitialAccumulator = null;
    +		if (initialAccumulator != null) {
    +			ttlInitialAccumulator = new TtlValue<>(initialAccumulator, Long.MAX_VALUE);
    +		}
    +		TtlSerializer<SV> ttlAccumulatorSerializer = new TtlSerializer<>(stateDesc.getSerializer());
    +		FoldingStateDescriptor<T, TtlValue<SV>> ttlDescriptor = new FoldingStateDescriptor<>(
    +			stateDesc.getName(),
    +			ttlInitialAccumulator,
    +			new TtlFoldFunction<>(userFoldingDescriptor.getFoldFunction(), ttlConfig, timeProvider),
    +			ttlAccumulatorSerializer);
    +		return (IS) new TtlFoldingState<>(
    +			originalStateFactory.createState(namespaceSerializer, ttlDescriptor),
    +			ttlConfig, timeProvider, stateDesc.getSerializer());
    +	}
    +
    +	private static class TtlSerializer<T> extends CompositeSerializer<TtlValue<T>>
{
    +		TtlSerializer(TypeSerializer<T> userValueSerializer) {
    +			super(Arrays.asList(userValueSerializer, new LongSerializer()));
    +		}
    +
    +		@Override
    +		@SuppressWarnings("unchecked")
    +		protected TtlValue<T> composeValue(List values) {
    +			return new TtlValue<>((T) values.get(0), (Long) values.get(1));
    +		}
    +
    +		@Override
    +		protected List decomposeValue(TtlValue<T> v) {
    +			return Arrays.asList(v.getUserValue(), v.getExpirationTimestamp());
    +		}
    +
    +		@Override
    +		@SuppressWarnings("unchecked")
    +		protected CompositeSerializer<TtlValue<T>> createSerializerInstance(List<TypeSerializer>
typeSerializers) {
    --- End diff --
    
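    Should we check that `typeSerializers != null && typeSerializers.size() == 2`? `TtlSerializer` is composed of two serializers: the user-value serializer and the `LongSerializer` for the expiration timestamp.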


---

Mime
View raw message