001// License: GPL. For details, see LICENSE file.
002package org.openstreetmap.josm.tools;
003
import java.awt.Dimension;
import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import java.awt.image.BufferedImage;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Set;
import java.util.TreeMap;
012
013/**
014 * Image warping algorithm.
015 *
016 * Deforms an image geometrically according to a given transformation formula.
017 * @since 11858
018 */
019public final class ImageWarp {
020
021    private ImageWarp() {
022        // Hide default constructor
023    }
024
025    /**
026     * Transformation that translates the pixel coordinates.
027     */
028    public interface PointTransform {
029        /**
030         * Translates pixel coordinates.
031         * @param pt pixel coordinates
032         * @return transformed pixel coordinates
033         */
034        Point2D transform(Point2D pt);
035    }
036
037    /**
038     * Wrapper that optimizes a given {@link ImageWarp.PointTransform}.
039     *
040     * It does so by spanning a grid with certain step size. It will invoke the
041     * potentially expensive master transform only at those grid points and use
042     * bilinear interpolation to approximate transformed values in between.
043     * <p>
044     * For memory optimization, this class assumes that rows are more or less scanned
045     * one-by-one as is done in {@link ImageWarp#warp}. I.e. this transform is <em>not</em>
046     * random access in the y coordinate.
047     */
048    public static class GridTransform implements ImageWarp.PointTransform {
049
050        private final double stride;
051        private final ImageWarp.PointTransform trfm;
052
053        private final Map<Integer, Map<Integer, Point2D>> cache;
054
055        private final boolean consistencyTest;
056        private final Set<Integer> deletedRows;
057
058        /**
059         * Create a new GridTransform.
060         * @param trfm the master transform, that needs to be optimized
061         * @param stride step size
062         */
063        public GridTransform(ImageWarp.PointTransform trfm, double stride) {
064            this.trfm = trfm;
065            this.stride = stride;
066            this.cache = new HashMap<>();
067            this.consistencyTest = Logging.isDebugEnabled();
068            if (consistencyTest) {
069                deletedRows = new HashSet<>();
070            } else {
071                deletedRows = null;
072            }
073        }
074
075        @Override
076        public Point2D transform(Point2D pt) {
077            int xIdx = (int) Math.floor(pt.getX() / stride);
078            int yIdx = (int) Math.floor(pt.getY() / stride);
079            double dx = pt.getX() / stride - xIdx;
080            double dy = pt.getY() / stride - yIdx;
081            Point2D value00 = getValue(xIdx, yIdx);
082            Point2D value01 = getValue(xIdx, yIdx + 1);
083            Point2D value10 = getValue(xIdx + 1, yIdx);
084            Point2D value11 = getValue(xIdx + 1, yIdx + 1);
085            double valueX = (value00.getX() * (1-dx) + value10.getX() * dx) * (1-dy) +
086                    (value01.getX() * (1-dx) + value11.getX() * dx) * dy;
087            double valueY = (value00.getY() * (1-dx) + value10.getY() * dx) * (1-dy) +
088                    (value01.getY() * (1-dx) + value11.getY() * dx) * dy;
089            return new Point2D.Double(valueX, valueY);
090        }
091
092        private Point2D getValue(int xIdx, int yIdx) {
093            return getRow(yIdx).computeIfAbsent(xIdx, k -> trfm.transform(new Point2D.Double(xIdx * stride, yIdx * stride)));
094        }
095
096        private Map<Integer, Point2D> getRow(int yIdx) {
097            cleanUp(yIdx - 3);
098            Map<Integer, Point2D> row = cache.get(yIdx);
099            if (row == null) {
100                row = new HashMap<>();
101                cache.put(yIdx, row);
102                if (consistencyTest) {
103                    // should not create a row that has been deleted before
104                    if (deletedRows.contains(yIdx)) throw new AssertionError();
105                    // only ever cache 3 rows at once
106                    if (cache.size() > 3) throw new AssertionError();
107                }
108            }
109            return row;
110        }
111
112        // remove rows from cache that will no longer be used
113        private void cleanUp(int yIdx) {
114            Map<Integer, Point2D> del = cache.remove(yIdx);
115            if (consistencyTest && del != null) {
116                // should delete each row only once
117                if (deletedRows.contains(yIdx)) throw new AssertionError();
118                deletedRows.add(yIdx);
119            }
120        }
121    }
122
123    /**
124     * Interpolation method.
125     */
126    public enum Interpolation {
127        /**
128         * Nearest neighbor.
129         *
130         * Simplest possible method. Faster, but not very good quality.
131         */
132        NEAREST_NEIGHBOR,
133
134        /**
135         * Bilinear.
136         *
137         * Decent quality.
138         */
139        BILINEAR;
140    }
141
142    /**
143     * Warp an image.
144     * @param srcImg the original image
145     * @param targetDim dimension of the target image
146     * @param invTransform inverse transformation (translates pixel coordinates
147     * of the target image to pixel coordinates of the original image)
148     * @param interpolation the interpolation method
149     * @return the warped image
150     */
151    public static BufferedImage warp(BufferedImage srcImg, Dimension targetDim, PointTransform invTransform, Interpolation interpolation) {
152        BufferedImage imgTarget = new BufferedImage(targetDim.width, targetDim.height, BufferedImage.TYPE_INT_ARGB);
153        Rectangle2D srcRect = new Rectangle2D.Double(0, 0, srcImg.getWidth(), srcImg.getHeight());
154        for (int j = 0; j < imgTarget.getHeight(); j++) {
155            for (int i = 0; i < imgTarget.getWidth(); i++) {
156                Point2D srcCoord = invTransform.transform(new Point2D.Double(i, j));
157                if (srcRect.contains(srcCoord)) {
158                    int rgba;
159                    switch (interpolation) {
160                        case NEAREST_NEIGHBOR:
161                            rgba = getColor((int) Math.round(srcCoord.getX()), (int) Math.round(srcCoord.getY()), srcImg);
162                            break;
163                        case BILINEAR:
164                            int x0 = (int) Math.floor(srcCoord.getX());
165                            double dx = srcCoord.getX() - x0;
166                            int y0 = (int) Math.floor(srcCoord.getY());
167                            double dy = srcCoord.getY() - y0;
168                            int c00 = getColor(x0, y0, srcImg);
169                            int c01 = getColor(x0, y0 + 1, srcImg);
170                            int c10 = getColor(x0 + 1, y0, srcImg);
171                            int c11 = getColor(x0 + 1, y0 + 1, srcImg);
172                            rgba = 0;
173                            // loop over color components: blue, green, red, alpha
174                            for (int ch = 0; ch <= 3; ch++) {
175                                int shift = 8 * ch;
176                                int chVal = (int) Math.round(
177                                    (((c00 >> shift) & 0xff) * (1-dx) + ((c10 >> shift) & 0xff) * dx) * (1-dy) +
178                                    (((c01 >> shift) & 0xff) * (1-dx) + ((c11 >> shift) & 0xff) * dx) * dy);
179                                rgba |= chVal << shift;
180                            }
181                            break;
182                        default:
183                            throw new AssertionError();
184                    }
185                    imgTarget.setRGB(i, j, rgba);
186                }
187            }
188        }
189        return imgTarget;
190    }
191
192    private static int getColor(int x, int y, BufferedImage img) {
193        // border strategy: continue with the color of the outermost pixel,
194        return img.getRGB(
195                Utils.clamp(x, 0, img.getWidth() - 1),
196                Utils.clamp(y, 0, img.getHeight() - 1));
197    }
198}