构建基于 Flink 实时特征流与 LangChain 的动态上下文 API 网关


我们基于 LangChain 构建的初代智能客服 API 已经上线,但很快就暴露了一个根本性问题:缺乏实时上下文。无论用户前一秒在我们的平台上做了什么,API 返回的回答总是千篇一律,它对用户的“当下”一无所知。业务方要求模型能够感知用户最近几分钟内的行为,例如,如果用户连续三次查询同一类商品失败,下一次对话应该主动提供帮助,而不是机械地回答“请问有什么可以帮您”。

传统的解决方案是查询业务数据库,但这对于需要捕捉时间窗口内(例如“过去5分钟”)行为序列的场景来说,不仅查询复杂、性能低下,而且对主业务数据库造成了巨大压力。我们需要一个能从用户行为流中实时提炼特征、并将其注入到 LLM 调用上下文中的架构。这个问题的核心在于,如何低延迟地连接用户在边缘(API Gateway)产生的行为与中心化的大模型服务。

最初的构想是在 API Gateway 层面对请求进行甄别,并将特定行为日志异步发送到消息队列。一个流处理系统消费这些日志,计算出实时特征,存入一个高速缓存。最后,当请求到达 LangChain 后端服务时,该服务先从缓存中拉取实时特征,构建出丰富的上下文,再调用大语言模型。这套架构的关键在于流处理引擎的选择和整个数据链路的无缝整合。

在流处理引擎的选择上,我们评估了 Kafka Streams 和 Apache Flink。Kafka Streams 对于简单的无状态或小状态的流处理很方便,但我们的场景需要复杂的窗口计算(例如,过去5分钟内某类API的调用次数、失败率)和可靠的状态管理。Apache Flink 在这方面是天生的赢家。其强大的窗口 API、精确一次(Exactly-once)的状态一致性保证以及为大规模、低延迟场景设计的分布式架构,使其成为构建实时特征引擎的不二之选。

最终确定的技术栈如下:

  1. API Gateway: Spring Cloud Gateway。利用其 GlobalFilter 捕获所有请求,作为实时事件的源头。
  2. 消息队列: Apache Kafka。作为网关与 Flink 之间解耦和削峰填谷的缓冲层。
  3. 流处理引擎: Apache Flink。核心组件,负责消费 Kafka 数据,进行窗口聚合,计算实时用户特征。
  4. 实时特征存储: Redis。低延迟的键值存储,用于存放 Flink 计算出的特征,供下游应用查询。
  5. LLM 应用框架: LangChain(通过 Java 服务调用),整合实时特征,生成动态 Prompt。
  6. 核心框架: Spring Boot,用于构建网关和业务服务。

整个数据流转的架构如下:

graph TD
    A[用户请求] --> B[Spring Cloud Gateway];
    B --> C{GlobalFilter: 行为捕获};
    C -- 异步发送 --> D[Kafka Topic: user-actions];
    E[Apache Flink Job] -- 消费 --> D;
    E --> F{窗口计算 & 特征提取};
    F -- 写入 --> G[Redis: real-time-features];
    
    subgraph LLM 服务调用
        H[用户请求至智能服务] --> I[Java Backend Service];
        I -- 查询特征 --> G;
        I -- 组合 Prompt --> J[LangChain Service];
        J -- 调用 --> K[LLM];
        K --> J;
        J --> I;
        I --> L[响应];
    end
    
    B --> I;

第一步:在 Spring Cloud Gateway 捕获行为事件

我们需要创建一个 GlobalFilter,它会在每个请求被路由之前执行。这个 Filter 的核心任务是提取关键信息(如用户ID、请求路径、时间戳),并将其封装成一个事件对象,然后通过 KafkaTemplate 发送到 Kafka。

在真实项目中,直接同步发送事件到 Kafka 是不可接受的,它会阻塞请求处理线程,增加网关的响应延迟。因此,我们必须采用异步发送,并添加回调来处理可能出现的发送失败。

UserActionCaptureFilter.java

package com.example.gateway.filter;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.support.SendResult;
import org.springframework.stereotype.Component;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.io.Serializable;
import java.time.Instant;

@Slf4j
@Component
public class UserActionCaptureFilter implements GlobalFilter, Ordered {

    private final KafkaTemplate<String, String> kafkaTemplate;
    private final ObjectMapper objectMapper;
    private static final String USER_ID_HEADER = "X-User-ID";
    private static final String KAFKA_TOPIC = "user-actions";

    public UserActionCaptureFilter(KafkaTemplate<String, String> kafkaTemplate, ObjectMapper objectMapper) {
        this.kafkaTemplate = kafkaTemplate;
        this.objectMapper = objectMapper;
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String userId = request.getHeaders().getFirst(USER_ID_HEADER);

        // 只捕获带有用户标识的请求
        if (userId != null && !userId.isEmpty()) {
            UserActionEvent event = new UserActionEvent(
                    userId,
                    request.getPath().value(),
                    request.getMethodValue(),
                    Instant.now().toEpochMilli()
            );

            try {
                String eventJson = objectMapper.writeValueAsString(event);
                // 异步发送,避免阻塞网关线程
                kafkaTemplate.send(KAFKA_TOPIC, event.getUserId(), eventJson)
                        .addCallback(new ListenableFutureCallback<>() {
                            @Override
                            public void onFailure(Throwable ex) {
                                log.error("Failed to send user action event to Kafka for user {}: {}", event.getUserId(), ex.getMessage());
                            }

                            @Override
                            public void onSuccess(SendResult<String, String> result) {
                                log.trace("Successfully sent user action event to Kafka partition {}", result.getRecordMetadata().partition());
                            }
                        });
            } catch (Exception e) {
                // JSON序列化失败等异常,仅记录日志,不影响主流程
                log.error("Error serializing or sending user action event", e);
            }
        }

        return chain.filter(exchange);
    }

    @Override
    public int getOrder() {
        // 确保在路由过滤器之前执行
        return -1;
    }

    @Data
    @AllArgsConstructor
    public static class UserActionEvent implements Serializable {
        private String userId;
        private String path;
        private String method;
        private long timestamp;
    }
}

application.yml 配置

spring:
  application:
    name: api-gateway
  cloud:
    gateway:
      routes:
        - id: llm_service_route
          uri: lb://llm-service
          predicates:
            - Path=/api/v1/chat/**
  kafka:
    bootstrap-servers: localhost:9092
    producer:
      key-serializer: org.apache.kafka.common.serialization.StringSerializer
      value-serializer: org.apache.kafka.common.serialization.StringSerializer
      # 生产环境建议设置为 all,保证数据不丢失
      acks: 1
      properties:
        # 增加吞吐量
        linger.ms: 10
        batch.size: 16384

这个过滤器优雅地解决了事件源的问题。它对网关的性能影响极小,同时保证了事件数据的可靠投递。

这是整个架构的核心。Flink 作业需要从 Kafka 消费 UserActionEvent JSON 字符串,将其反序列化,然后按 userId 进行分组(keyBy),最后应用滑动窗口(TumblingEventTimeWindowsSlidingEventTimeWindows)来计算特征。

我们定义一个场景:计算用户在过去5分钟内,访问不同API端点的数量,以及总请求次数。

RealTimeFeatureEngineeringJob.java

package com.example.flink;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.connector.kafka.source.KafkaSource;
import org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.streaming.connectors.redis.RedisSink;
import org.apache.flink.streaming.connectors.redis.common.config.FlinkJedisPoolConfig;
import org.apache.flink.streaming.connectors.redis.common.mapper.RedisCommand;
import org.apache.flink.streaming.connectors.redis.common.mapper.RedisCommandDescription;
import org.apache.flink.streaming.connectors.redis.common.mapper.RedisMapper;
import java.io.Serializable;
import java.time.Duration;
import java.util.HashSet;
import java.util.Set;

public class RealTimeFeatureEngineeringJob {

    public static void main(String[] args) throws Exception {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        final ParameterTool params = ParameterTool.fromArgs(args);

        // Kafka Source
        KafkaSource<String> source = KafkaSource.<String>builder()
                .setBootstrapServers(params.get("kafka.bootstrap.servers", "localhost:9092"))
                .setTopics(params.get("kafka.topic", "user-actions"))
                .setGroupId(params.get("flink.group.id", "flink-feature-consumer"))
                .setStartingOffsets(OffsetsInitializer.latest())
                .setValueOnlyDeserializer(new SimpleStringSchema())
                .build();

        // DataStream from Kafka
        DataStream<String> kafkaStream = env.fromSource(source, WatermarkStrategy.forMonotonousTimestamps(), "Kafka Source");
        
        // 1. Deserialization and Watermark Assignment
        DataStream<UserActionEvent> eventStream = kafkaStream
                .map(new JsonDeserializer())
                .assignTimestampsAndWatermarks(
                        WatermarkStrategy.<UserActionEvent>forBoundedOutOfOrderness(Duration.ofSeconds(5))
                                .withTimestampAssigner((event, timestamp) -> event.getTimestamp())
                );

        // 2. Key by userId and apply window
        DataStream<UserFeature> featureStream = eventStream
                .keyBy(UserActionEvent::getUserId)
                .window(TumblingEventTimeWindows.of(Time.minutes(5)))
                // AggregateFunction 相比 ProcessWindowFunction 更高效,因为它在窗口触发前就进行增量聚合
                .aggregate(new UserActionAggregator());

        // 3. Sink to Redis
        FlinkJedisPoolConfig jedisPoolConfig = new FlinkJedisPoolConfig.Builder()
                .setHost(params.get("redis.host", "localhost"))
                .setPort(params.getInt("redis.port", 6379))
                .build();

        featureStream.addSink(new RedisSink<>(jedisPoolConfig, new UserFeatureRedisMapper()));

        env.execute("Real-time User Feature Engineering");
    }

    // --- Helper Classes ---

    public static class JsonDeserializer implements org.apache.flink.api.common.functions.MapFunction<String, UserActionEvent> {
        private static final ObjectMapper objectMapper = new ObjectMapper();
        @Override
        public UserActionEvent map(String value) throws Exception {
            try {
                return objectMapper.readValue(value, UserActionEvent.class);
            } catch (Exception e) {
                // 在生产环境中,这里应该将解析失败的消息发送到死信队列
                return new UserActionEvent("PARSE_ERROR", "", "", 0L);
            }
        }
    }
    
    // Flink 的 AggregateFunction<IN, ACC, OUT>
    // IN: UserActionEvent, ACC: 累加器 (RequestStats), OUT: UserFeature
    public static class UserActionAggregator implements AggregateFunction<UserActionEvent, RequestStats, UserFeature> {
        @Override
        public RequestStats createAccumulator() {
            return new RequestStats();
        }

        @Override
        public RequestStats add(UserActionEvent value, RequestStats accumulator) {
            accumulator.totalRequests++;
            accumulator.distinctPaths.add(value.getPath());
            accumulator.userId = value.getUserId(); // userId 在这个 key 分组下是相同的
            accumulator.windowEndTimestamp = System.currentTimeMillis() + (5 * 60 * 1000); // 估算一个过期时间
            return accumulator;
        }

        @Override
        public UserFeature getResult(RequestStats accumulator) {
            return new UserFeature(
                    accumulator.userId,
                    accumulator.totalRequests,
                    accumulator.distinctPaths.size(),
                    accumulator.windowEndTimestamp
            );
        }

        @Override
        public RequestStats merge(RequestStats a, RequestStats b) {
            a.totalRequests += b.totalRequests;
            a.distinctPaths.addAll(b.distinctPaths);
            return a;
        }
    }

    // Redis Sink Mapper
    public static class UserFeatureRedisMapper implements RedisMapper<UserFeature> {
        private final ObjectMapper objectMapper = new ObjectMapper();

        @Override
        public RedisCommandDescription getCommandDescription() {
            // 使用 HSET, key 是 "user_features", field 是 userId, value 是特征JSON
            return new RedisCommandDescription(RedisCommand.HSET, "user_features");
        }

        @Override
        public String getKeyFromData(UserFeature data) {
            return data.getUserId();
        }

        @Override
        public String getValueFromData(UserFeature data) {
            try {
                return objectMapper.writeValueAsString(data);
            } catch (Exception e) {
                // 序列化失败处理
                return "{}";
            }
        }
    }

    // --- Data Structures ---
    
    @lombok.Data @lombok.AllArgsConstructor @lombok.NoArgsConstructor
    public static class UserActionEvent implements Serializable {
        private String userId;
        private String path;
        private String method;
        private long timestamp;
    }
    
    @lombok.Data @lombok.AllArgsConstructor @lombok.NoArgsConstructor
    public static class UserFeature implements Serializable {
        private String userId;
        private long requestCountLast5Min;
        private int distinctPathsLast5Min;
        private long expiresAt;
    }

    public static class RequestStats {
        String userId;
        long totalRequests = 0L;
        Set<String> distinctPaths = new HashSet<>();
        long windowEndTimestamp;
    }
}

代码要点分析:

  • 水印与事件时间: 使用 forBoundedOutOfOrderness 策略处理网络延迟等导致的事件乱序问题,这是保证时间窗口计算准确性的关键。
  • AggregateFunction: 相比于将整个窗口数据缓存起来再处理的 ProcessWindowFunctionAggregateFunction 实现了增量聚合,极大地节省了内存和计算资源,对于简单的计数、求和场景是最佳选择。
  • 状态后端: 在生产环境中,需要为 Flink 作业配置状态后端,例如 RocksDBStateBackend,以支持大状态和故障恢复。
  • Redis Sink: 我们选择 HSET 命令,将所有用户的特征存储在同一个 Redis Hash 结构中,keyuser_featuresfielduserId。这种方式比为每个用户创建一个独立的 key 更节省内存,也便于管理。

第三步:后端服务融合特征与 LangChain

最后一步是改造我们的 LLM 后端服务。它在接收到请求后,不再是直接调用 LangChain,而是增加一个“特征注入”步骤。

LlmOrchestrationService.java

package com.example.llmservice.service;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import java.util.Optional;

@Slf4j
@Service
public class LlmOrchestrationService {

    private final StringRedisTemplate redisTemplate;
    private final RestTemplate restTemplate; // 用于调用 Python LangChain 服务
    private final ObjectMapper objectMapper;
    private static final String REDIS_HASH_KEY = "user_features";
    private static final String LANGCHAIN_SERVICE_URL = "http://langchain-service/generate";

    public LlmOrchestrationService(StringRedisTemplate redisTemplate, RestTemplate restTemplate, ObjectMapper objectMapper) {
        this.redisTemplate = redisTemplate;
        this.restTemplate = restTemplate;
        this.objectMapper = objectMapper;
    }

    public String generateResponse(String userId, String query) {
        // 1. 获取实时特征
        Optional<UserFeature> userFeature = getRealTimeFeatures(userId);

        // 2. 构建动态 Prompt
        String prompt = buildDynamicPrompt(query, userFeature);
        log.info("Constructed prompt for user {}: {}", userId, prompt);

        // 3. 调用 LangChain 服务
        // LangChainRequest request = new LangChainRequest(prompt);
        // return restTemplate.postForObject(LANGCHAIN_SERVICE_URL, request, String.class);
        return "Simulated LLM response for prompt: " + prompt; // 模拟调用
    }

    private Optional<UserFeature> getRealTimeFeatures(String userId) {
        try {
            String featureJson = (String) redisTemplate.opsForHash().get(REDIS_HASH_KEY, userId);
            if (featureJson != null) {
                return Optional.of(objectMapper.readValue(featureJson, UserFeature.class));
            }
        } catch (Exception e) {
            log.error("Failed to retrieve or parse features for user {} from Redis", userId, e);
        }
        return Optional.empty();
    }

    private String buildDynamicPrompt(String query, Optional<UserFeature> feature) {
        StringBuilder contextBuilder = new StringBuilder();
        contextBuilder.append("User query: \"").append(query).append("\"\n");

        feature.ifPresent(f -> {
            contextBuilder.append("\n--- Real-time User Context (last 5 minutes) ---\n");
            contextBuilder.append(String.format("- Total interactions: %d\n", f.getRequestCountLast5Min()));
            contextBuilder.append(String.format("- Number of distinct features explored: %d\n", f.getDistinctPathsLast5Min()));

            if (f.getRequestCountLast5Min() > 10) {
                 contextBuilder.append("- User is highly active right now.\n");
            }
            contextBuilder.append("--------------------------------------------\n");
        });
        
        contextBuilder.append("\nBased on the query and the real-time context, provide a helpful and personalized response.");
        return contextBuilder.toString();
    }

    // DTO for Redis data
    @lombok.Data
    public static class UserFeature {
        private long requestCountLast5Min;
        private int distinctPathsLast5Min;
    }
}

通过这个服务,我们成功将 Flink 计算出的实时特征注入到了 Prompt 中。现在,LLM 接收到的不再仅仅是用户的孤立问题,还有一个描述其当前行为状态的上下文摘要。这使得模型可以生成更具相关性和个性化的回答。

方案的局限性与未来展望

当前这套架构有效地解决了实时上下文注入的问题,但并非没有局限性。首先,特征维度相对单一,仅包含了请求计数等。未来可以引入更复杂的特征,例如分析请求 Body 内容(需要注意隐私和性能开销),或者通过 Flink 的 ProcessFunction 结合状态实现更复杂的序列模式匹配(例如,识别“用户连续三次搜索A但未点击”的行为模式)。

其次,整个链路的延迟依赖于各个组件的性能,特别是 Flink 窗口的长度和触发频率。5分钟的窗口对于某些场景可能还是太长,缩短窗口会增加计算和存储的压力,这需要根据业务需求进行权衡。一个可行的优化路径是采用混合窗口策略,例如同时计算1分钟、5分钟、15分钟的特征,为不同场景提供不同时间粒度的上下文。

最后,目前的特征还只是行为统计。更进一步的探索是将这些实时特征与用户画像、历史订单等离线数据在 LangChain 服务层进行融合,构建一个真正立体、动态的用户上下文,这将是驱动下一代智能应用的核心。


  目录